forked from DataIntelligenceCrew/go-faiss
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsearch_params.go
173 lines (149 loc) · 4.85 KB
/
search_params.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package faiss
/*
#include <faiss/c_api/Index_c.h>
#include <faiss/c_api/IndexIVF_c.h>
#include <faiss/c_api/impl/AuxIndexStructures_c.h>
*/
import "C"
import (
"encoding/json"
"fmt"
)
// SearchParams wraps a handle to a native FAISS search-parameters object.
// Call Delete to release the native object once it is no longer needed.
type SearchParams struct {
// sp is the underlying C handle; Delete treats a nil sp as nothing to free.
sp *C.FaissSearchParameters
}
// Delete frees the native memory associated with s.
//
// It is a no-op on a nil receiver or when s holds no native object. The
// handle is cleared after the free, so calling Delete more than once cannot
// release the same native object twice.
func (s *SearchParams) Delete() {
	if s == nil || s.sp == nil {
		return
	}
	C.faiss_SearchParameters_free(s.sp)
	// Drop the now-dangling handle so a repeated Delete is harmless.
	s.sp = nil
}
// searchParamsIVF holds the search-time IVF overrides decoded from a query's
// JSON parameters. Both fields are percentages that Validate restricts to
// [0, 100]; a value of 0 is treated as "not set" by the constructors below.
type searchParamsIVF struct {
// NprobePct, when > 0, sets nprobe to this percentage of nlist (at least 1).
NprobePct float32 `json:"ivf_nprobe_pct,omitempty"`
// MaxCodesPct, when > 0, sets max codes to this percentage of the vector count.
MaxCodesPct float32 `json:"ivf_max_codes_pct,omitempty"`
}
// defaultSearchParamsIVF holds IVF parameters used to override the index-time
// defaults for a specific query. They serve as the 'new' defaults for this
// query, unless overridden by search-time params.
type defaultSearchParamsIVF struct {
// Nprobe, when > 0, replaces the nprobe value read from the index.
Nprobe int `json:"ivf_nprobe,omitempty"`
// Nlist, when > 0, replaces the nlist value read from the index.
Nlist int `json:"ivf_nlist,omitempty"`
// Nvecs is a per-query vector count.
// NOTE(review): presumably meant to stand in for the index's total vector
// count when > 0 — confirm against callers.
Nvecs int `json:"ivf_nvecs,omitempty"`
}
// Validate reports whether both percentage fields lie within [0, 100].
// It returns an error describing the first field found outside that range
// (ivf_nprobe_pct is checked before ivf_max_codes_pct), or nil otherwise.
func (s *searchParamsIVF) Validate() error {
	// outOfRange is true for values strictly below 0 or strictly above 100.
	outOfRange := func(pct float32) bool {
		return pct < 0 || pct > 100
	}

	switch {
	case outOfRange(s.NprobePct):
		return fmt.Errorf("invalid IVF search params, ivf_nprobe_pct:%v, "+
			"should be in range [0, 100]", s.NprobePct)
	case outOfRange(s.MaxCodesPct):
		return fmt.Errorf("invalid IVF search params, ivf_max_codes_pct:%v, "+
			"should be in range [0, 100]", s.MaxCodesPct)
	default:
		return nil
	}
}
// getNProbeFromSearchParams reads nprobe from the native search parameters
// wrapped by params and returns it as an int32.
//
// params must be non-nil; its handle is passed to the FAISS C API unchanged.
// NOTE(review): presumably the handle must refer to IVF search parameters —
// confirm against callers.
func getNProbeFromSearchParams(params *SearchParams) int32 {
	nprobe := C.faiss_SearchParametersIVF_nprobe(params.sp)
	return int32(nprobe)
}
// NewSearchParamsIVF builds the search parameters for a single query against
// idx, treating defaultParams as per-query replacements for the index-time
// defaults.
//
// params is optional JSON carrying "ivf_nprobe_pct" and "ivf_max_codes_pct";
// sel is an optional ID selector handed to FAISS as-is.
//
// When idx is not an IVF index, or when neither params nor sel is supplied,
// the returned SearchParams holds no native object.
//
// Otherwise the values passed to FAISS are resolved as follows:
//   - nlist:    defaultParams.Nlist if > 0, else the index's nlist.
//   - nprobe:   ivf_nprobe_pct percent of nlist (at least 1) if the percentage
//     is > 0; else defaultParams.Nprobe if > 0; else the index's nprobe.
//   - maxCodes: ivf_max_codes_pct percent of the vector count if the
//     percentage is > 0, where the vector count is defaultParams.Nvecs if > 0,
//     else idx.Ntotal(); 0 otherwise.
//
// A non-nil *SearchParams is returned on every path, including errors; Delete
// may be called on it unconditionally.
func NewSearchParamsIVF(idx Index, params json.RawMessage, sel *C.FaissIDSelector,
	defaultParams defaultSearchParamsIVF) (*SearchParams, error) {
	rv := &SearchParams{}

	ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr())
	if ivfIdx == nil {
		// Not an IVF index: nothing to configure.
		return rv, nil
	}
	if len(params) == 0 && sel == nil {
		return rv, nil
	}

	var ivfParams searchParamsIVF
	if len(params) > 0 {
		if err := json.Unmarshal(params, &ivfParams); err != nil {
			return rv, fmt.Errorf("failed to unmarshal IVF search params, "+
				"err:%w", err)
		}
		if err := ivfParams.Validate(); err != nil {
			return rv, err
		}
	}

	nlist := int(C.faiss_IndexIVF_nlist(ivfIdx))
	if defaultParams.Nlist > 0 {
		nlist = defaultParams.Nlist
	}

	// It's important to start from the nprobe decided at the time of index
	// creation. Otherwise, nprobe will be set to the default value of 1.
	nprobe := int(C.faiss_IndexIVF_nprobe(ivfIdx))
	if defaultParams.Nprobe > 0 {
		nprobe = defaultParams.Nprobe
	}
	if ivfParams.NprobePct > 0 {
		// In the situation when the calculated nprobe happens to be
		// between 0 and 1, we'll round it up.
		nprobe = max(int(float32(nlist)*(ivfParams.NprobePct/100)), 1)
	}

	// maxCodes stays 0 unless a percentage was requested; 0 means no limit.
	var maxCodes int
	if ivfParams.MaxCodesPct > 0 {
		nvecs := float32(idx.Ntotal())
		// Honour the per-query vector-count override the same way Nlist and
		// Nprobe are honoured above.
		if defaultParams.Nvecs > 0 {
			nvecs = float32(defaultParams.Nvecs)
		}
		maxCodes = int(nvecs * (ivfParams.MaxCodesPct / 100))
	}

	if c := C.faiss_SearchParametersIVF_new_with(
		&rv.sp,
		sel,
		C.size_t(nprobe),
		C.size_t(maxCodes),
	); c != 0 {
		return rv, fmt.Errorf("failed to create faiss IVF search params")
	}
	return rv, nil
}
// NewSearchParams builds the search parameters for a single query against idx.
//
// It always returns a non-nil SearchParams object, thus the caller must clean
// up the object by invoking its Delete() method, even if an error is returned.
//
// params is optional JSON carrying "ivf_nprobe_pct" and "ivf_max_codes_pct";
// sel is an optional ID selector handed to FAISS as-is.
//
// For a non-IVF index the result wraps the base parameters created with sel.
// For an IVF index with neither params nor sel, the result is whatever the
// IVF downcast of those base parameters yields. Otherwise IVF parameters are
// created with:
//   - nprobe:   ivf_nprobe_pct percent of the index's nlist (at least 1) if
//     the percentage is > 0, else the index's own nprobe.
//   - maxCodes: ivf_max_codes_pct percent of the index's total vector count if
//     the percentage is > 0, else 0.
//
// Native objects that are superseded along the way are freed rather than
// dropped, so at most one native object is owned by the result.
func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector,
) (*SearchParams, error) {
	rv := &SearchParams{}
	if c := C.faiss_SearchParameters_new(&rv.sp, sel); c != 0 {
		return rv, fmt.Errorf("failed to create faiss search params")
	}

	// Check if the index is IVF; if not, the base params are all we need.
	ivfIdx := C.faiss_IndexIVF_cast(idx.cPtr())
	if ivfIdx == nil {
		return rv, nil
	}

	// Replace the handle with its IVF view. If the downcast yields nil the
	// base object would become unreachable, so release it here instead of
	// leaking it.
	base := rv.sp
	rv.sp = C.faiss_SearchParametersIVF_cast(base)
	if rv.sp == nil && base != nil {
		C.faiss_SearchParameters_free(base)
	}
	if len(params) == 0 && sel == nil {
		return rv, nil
	}

	var ivfParams searchParamsIVF
	if len(params) > 0 {
		if err := json.Unmarshal(params, &ivfParams); err != nil {
			return rv, fmt.Errorf("failed to unmarshal IVF search params, "+
				"err:%w", err)
		}
		if err := ivfParams.Validate(); err != nil {
			return rv, err
		}
	}

	var nprobe, maxCodes int
	if ivfParams.NprobePct > 0 {
		nlist := float32(C.faiss_IndexIVF_nlist(ivfIdx))
		// In the situation when the calculated nprobe happens to be
		// between 0 and 1, we'll round it up.
		nprobe = max(int(nlist*(ivfParams.NprobePct/100)), 1)
	} else {
		// It's important to set nprobe to the value decided at the time of
		// index creation. Otherwise, nprobe will be set to the default
		// value of 1.
		nprobe = int(C.faiss_IndexIVF_nprobe(ivfIdx))
	}
	if ivfParams.MaxCodesPct > 0 {
		nvecs := C.faiss_Index_ntotal(idx.cPtr())
		maxCodes = int(float32(nvecs) * (ivfParams.MaxCodesPct / 100))
	} // else, maxCodes keeps the default value of 0, which means no limit

	// Create the IVF params into a scratch handle so that rv keeps its
	// current handle if creation fails, and so the handle being replaced can
	// be freed once creation succeeds.
	var ivfSP *C.FaissSearchParameters
	if c := C.faiss_SearchParametersIVF_new_with(
		&ivfSP,
		sel,
		C.size_t(nprobe),
		C.size_t(maxCodes),
	); c != 0 {
		return rv, fmt.Errorf("failed to create faiss IVF search params")
	}
	if rv.sp != nil {
		C.faiss_SearchParameters_free(rv.sp)
	}
	rv.sp = ivfSP
	return rv, nil
}