Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KNN Operator #1908

Merged
merged 2 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions index_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,12 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
ctx = context.WithValue(ctx, search.GeoBufferPoolCallbackKey,
search.GeoBufferPoolCallbackFunc(getBufferPool))

// Using a disjunction query to get union of results from KNN query
// and the original query
searchQuery := disjunctQueryWithKNN(req)
// Using a query to get results from KNN queries
// and the original query based on the KNN operator.
searchQuery, err := queryWithKNN(req)
if err != nil {
return nil, err
}

searcher, err := searchQuery.Searcher(ctx, indexReader, i.m, search.SearcherOptions{
Explain: req.Explain,
Expand Down
36 changes: 30 additions & 6 deletions search_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"github.com/blevesearch/bleve/v2/search/query"
)

type knnOperator string

type SearchRequest struct {
Query query.Query `json:"query"`
Size int `json:"size"`
Expand All @@ -40,7 +42,8 @@ type SearchRequest struct {
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`

KNN []*KNNRequest `json:"knn"`
KNN []*KNNRequest `json:"knn"`
KNNOperator knnOperator `json:"knn_operator"`

sortFunc func(sort.Interface)
}
Expand All @@ -62,6 +65,10 @@ func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost fl
})
}

func (r *SearchRequest) AddKNNOperator(operator knnOperator) {
r.KNNOperator = operator
}

// UnmarshalJSON deserializes a JSON representation of
// a SearchRequest
func (r *SearchRequest) UnmarshalJSON(input []byte) error {
Expand All @@ -79,6 +86,7 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error {
SearchAfter []string `json:"search_after"`
SearchBefore []string `json:"search_before"`
KNN []*KNNRequest `json:"knn"`
KNNOperator knnOperator `json:"knn_operator"`
}

err := json.Unmarshal(input, &temp)
Expand Down Expand Up @@ -121,6 +129,10 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error {
}

r.KNN = temp.KNN
r.KNNOperator = temp.KNNOperator
if r.KNNOperator == "" {
r.KNNOperator = knnOperatorOr
}

return nil

Expand All @@ -143,26 +155,38 @@ func copySearchRequest(req *SearchRequest) *SearchRequest {
SearchAfter: req.SearchAfter,
SearchBefore: req.SearchBefore,
KNN: req.KNN,
KNNOperator: req.KNNOperator,
}
return &rv

}

func disjunctQueryWithKNN(req *SearchRequest) query.Query {
var (
knnOperatorAnd = knnOperator("and")
knnOperatorOr = knnOperator("or")
)

func queryWithKNN(req *SearchRequest) (query.Query, error) {
if len(req.KNN) > 0 {
disjuncts := []query.Query{req.Query}
subQueries := []query.Query{req.Query}
for _, knn := range req.KNN {
if knn != nil {
knnQuery := query.NewKNNQuery(knn.Vector)
knnQuery.SetFieldVal(knn.Field)
knnQuery.SetK(knn.K)
knnQuery.SetBoost(knn.Boost.Value())
disjuncts = append(disjuncts, knnQuery)
subQueries = append(subQueries, knnQuery)
}
}
return query.NewDisjunctionQuery(disjuncts)
if req.KNNOperator == knnOperatorAnd {
Thejas-bhat marked this conversation as resolved.
Show resolved Hide resolved
return query.NewConjunctionQuery(subQueries), nil
} else if req.KNNOperator == knnOperatorOr || req.KNNOperator == "" {
return query.NewDisjunctionQuery(subQueries), nil
} else {
return nil, fmt.Errorf("unknown knn operator: %s", req.KNNOperator)
}
}
return req.Query
return req.Query, nil
}

func validateKNN(req *SearchRequest) error {
Expand Down
155 changes: 155 additions & 0 deletions search_knn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright (c) 2023 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors && vectors
// +build vectors,vectors

package bleve

import (
"fmt"
"log"
"math/rand"
"strconv"
"testing"

"github.com/blevesearch/bleve/v2/mapping"
"github.com/blevesearch/bleve/v2/search/query"
)

func getRandomQueryVec(dims int) []float32 {
vec := make([]float32, dims)
for i := 0; i < dims; i++ {
vec[i] = rand.Float32()
}
return vec
}

// Test to see if KNN Operators get added right to the query.
func TestKNNOperator(t *testing.T) {
tmpIndexPath := createTmpIndexPath(t)
defer cleanupTmpIndexPath(t, tmpIndexPath)

dataset := make([]map[string]interface{}, 10)

// Indexing just a few docs to populate index.
for i := 0; i < 10; i++ {
docVec := []float32{}
for i := 0; i < 5; i++ {
docVec = append(docVec, rand.Float32())
}
dataset = append(dataset, map[string]interface{}{
"type": "vectorStuff",
"content": strconv.Itoa(i),
"vector": docVec,
})
}

indexMapping := NewIndexMapping()
indexMapping.TypeField = "type"
indexMapping.DefaultAnalyzer = "en"
documentMapping := NewDocumentMapping()
indexMapping.AddDocumentMapping("vectorStuff", documentMapping)

contentFieldMapping := NewTextFieldMapping()
contentFieldMapping.Index = true
contentFieldMapping.Store = true
documentMapping.AddFieldMappingsAt("content", contentFieldMapping)

vecFieldMapping := mapping.NewVectorFieldMapping()
vecFieldMapping.Index = true
vecFieldMapping.Dims = 5
vecFieldMapping.Similarity = "dot_product"
documentMapping.AddFieldMappingsAt("vector", vecFieldMapping)

index, err := New(tmpIndexPath, indexMapping)
if err != nil {
log.Fatal(err)
}
defer func() {
err := index.Close()
if err != nil {
log.Fatal(err)
}
}()

batch := index.NewBatch()
for i := 0; i < len(dataset); i++ {
batch.Index(strconv.Itoa(i), dataset[i])
}

err = index.Batch(batch)
if err != nil {
log.Fatal(err)
}

termQuery := query.NewTermQuery("2")

searchRequest := NewSearchRequest(termQuery)
queryVec2 := getRandomQueryVec(5)
searchRequest.AddKNN("vector", queryVec2, 3, 2.0)
searchRequest.AddKNN("vector", queryVec2, 2, 1.5)
searchRequest.Fields = []string{"content", "vector"}

// Conjunction
searchRequest.AddKNNOperator(knnOperatorAnd)
conjunction, err := queryWithKNN(searchRequest)
if err != nil {
log.Fatal(fmt.Errorf("unexpected error for AND knn operator"))
}

conj, ok := conjunction.(*query.ConjunctionQuery)
if !ok {
log.Fatal(fmt.Errorf("expected conjunction query"))
}

if len(conj.Conjuncts) == 3 {
_, ok := conj.Conjuncts[0].(*query.TermQuery)
if !ok {
log.Fatal(fmt.Errorf("expected first query to be a term query,"+
" but it's %T", conj.Conjuncts[0]))
}
} else {
log.Fatal(fmt.Errorf("expected 3 conjuncts"))
}

// Disjunction
searchRequest.AddKNNOperator(knnOperatorOr)
disjunction, err := queryWithKNN(searchRequest)
if err != nil {
log.Fatal(fmt.Errorf("unexpected error for OR knn operator"))
}

disj, ok := disjunction.(*query.DisjunctionQuery)
if !ok {
log.Fatal(fmt.Errorf("expected disjunction query"))
}

if len(disj.Disjuncts) == 3 {
_, ok := disj.Disjuncts[0].(*query.TermQuery)
if !ok {
log.Fatal(fmt.Errorf("expected first query to be a term query,"+
" but it's %T", conj.Conjuncts[0]))
}
} else {
log.Fatal(fmt.Errorf("expected 3 disjuncts"))
}

// Incorrect operator.
searchRequest.AddKNNOperator("bs_op")
searchRequest.Query, err = queryWithKNN(searchRequest)
if err == nil {
log.Fatal(fmt.Errorf("expected error for incorrect knn operator"))
}
}
4 changes: 2 additions & 2 deletions search_no_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ func copySearchRequest(req *SearchRequest) *SearchRequest {
return &rv
}

func disjunctQueryWithKNN(req *SearchRequest) query.Query {
return req.Query
func queryWithKNN(req *SearchRequest) (query.Query, error) {
return req.Query, nil
}

func validateKNN(req *SearchRequest) error {
Expand Down
Loading