Toy: Limited Training Size #274

Open · wants to merge 1 commit into base: master
209 changes: 149 additions & 60 deletions section_faiss_vector_index.go
@@ -281,7 +281,7 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,

// safe to assume that all the indexes share the same config values, given
// that they are extracted from the field mapping info.
var dims, metric int
var dims, metric, nvecs int
var indexOptimizedFor string

var validMerge bool
@@ -308,6 +308,7 @@
}
if len(vecIndexes[segI].vecIds) > 0 {
indexReconsLen := len(vecIndexes[segI].vecIds) * index.D()
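// nvecs accumulates the total number of live vectors across the segments
// being merged; it drives the centroid count, the index description, and
// the choice between single-shot and sampled training further below.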
nvecs += len(vecIndexes[segI].vecIds)
if indexReconsLen > reconsCap {
reconsCap = indexReconsLen
}
@@ -328,59 +329,11 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
return nil
}

finalVecIDs := make([]int64, 0, finalVecIDCap)
// merge the indexes using the reconstruction method.
// indexes[i].vecIds holds only the valid vectors of this index,
// so we reconstruct only those.
indexData := make([]float32, 0, indexDataCap)
// reusable buffer for reconstruction
recons := make([]float32, 0, reconsCap)
var err error
for i := 0; i < len(vecIndexes); i++ {
if isClosed(closeCh) {
freeReconstructedIndexes(vecIndexes)
return seg.ErrClosed
}

// reconstruct the vectors only if present; it could be that
// some of the indexes had all of their vectors updated/deleted.
if len(vecIndexes[i].vecIds) > 0 {
neededReconsLen := len(vecIndexes[i].vecIds) * vecIndexes[i].index.D()
recons = recons[:neededReconsLen]
// todo: parallelize reconstruction
recons, err = vecIndexes[i].index.ReconstructBatch(vecIndexes[i].vecIds, recons)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}
indexData = append(indexData, recons...)
// Adding vector IDs in the same order as the vectors
finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds...)
}
}

if len(indexData) == 0 {
// no valid vectors for this index, so we don't even have to
// record it in the section
freeReconstructedIndexes(vecIndexes)
return nil
}
recons = nil

nvecs := len(finalVecIDs)

// decide which index type to create after the merge, based on the number
// of vectors that will be added to it.
nlist := determineCentroids(nvecs)
indexDescription, indexClass := determineIndexToUse(nvecs, nlist, indexOptimizedFor)
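// nlist is the number of IVF centroids (k-means clusters) to partition the
// vector space into, and indexClass distinguishes flat indexes from IVF
// ones, which decides below whether a training step is needed at all.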

// free the reconstructed indexes immediately rather than waiting until the
// end: the operations below no longer need them, and holding on to them
// would tie up memory, which can be detrimental while creating indexes
// during introduction.
freeReconstructedIndexes(vecIndexes)
vecIndexes = nil

faissIndex, err := faiss.IndexFactory(dims, indexDescription, metric)
if err != nil {
return err
@@ -398,24 +351,160 @@

nprobe := calculateNprobe(nlist, indexOptimizedFor)
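// nprobe is the number of centroid cells faiss scans per query; larger
// values improve recall at the cost of query latency.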
faissIndex.SetNProbe(nprobe)
}

// train the index all at once if the number of vectors is below a certain
// threshold; otherwise train it on a bounded sample of the vectors and add
// the rest incrementally afterwards
if nvecs < 100000 {
finalVecIDs := make([]int64, 0, finalVecIDCap)
// merge the indexes using the reconstruction method.
// indexes[i].vecIds holds only the valid vectors of this index,
// so we reconstruct only those.
indexData := make([]float32, 0, indexDataCap)
// reusable buffer for reconstruction
recons := make([]float32, 0, reconsCap)
var err error
for i := 0; i < len(vecIndexes); i++ {
if isClosed(closeCh) {
freeReconstructedIndexes(vecIndexes)
return seg.ErrClosed
}

// reconstruct the vectors only if present; it could be that
// some of the indexes had all of their vectors updated/deleted.
if len(vecIndexes[i].vecIds) > 0 {
neededReconsLen := len(vecIndexes[i].vecIds) * vecIndexes[i].index.D()
recons = recons[:neededReconsLen]
// todo: parallelize reconstruction
recons, err = vecIndexes[i].index.ReconstructBatch(vecIndexes[i].vecIds, recons)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}
indexData = append(indexData, recons...)
// Adding vector IDs in the same order as the vectors
finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds...)
}
}

if len(indexData) == 0 {
// no valid vectors for this index, so we don't even have to
// record it in the section
freeReconstructedIndexes(vecIndexes)
return nil
}

recons = nil
// free the reconstructed indexes immediately rather than waiting until the
// end: the operations below no longer need them, and holding on to them
// would tie up memory, which can be detrimental while creating indexes
// during introduction.
freeReconstructedIndexes(vecIndexes)
vecIndexes = nil

// train the vector index; this essentially performs k-means clustering to
// partition the data space of indexData so that at search time we probe
// only a subset of vectors -> non-exhaustive search. this can be a
// time-consuming step when indexData is large.
err = faissIndex.Train(indexData)
if indexClass == IndexTypeIVF {
// train the vector index; this essentially performs k-means clustering to
// partition the data space of indexData so that at search time we probe
// only a subset of vectors -> non-exhaustive search. this can be a
// time-consuming step when indexData is large.
err = faissIndex.Train(indexData)
if err != nil {
return err
}
}
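// flat (non-IVF) indexes store vectors exactly and need no training,
// which is why Train is now gated on the index class.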
err = faissIndex.AddWithIDs(indexData, finalVecIDs)
if err != nil {
return err
}
}

err = faissIndex.AddWithIDs(indexData, finalVecIDs)
if err != nil {
return err
indexData = nil
finalVecIDs = nil
} else {
recons := make([]float32, 0, reconsCap)
curVecs := 0
vecLimit := 100000
if vecLimit < nlist*40 {
vecLimit = nlist * 40
}
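// the nlist*40 floor keeps roughly 40 training vectors per centroid,
// in line with faiss's clustering warning threshold of ~39 points per
// centroid.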

finalVecIDs := make([]int64, 0, vecLimit)
indexData := make([]float32, 0, vecLimit*dims)
trained := false
var err error
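// streaming strategy: buffer up to vecLimit vectors; train the index once
// on the first full buffer (a sample of the data), then flush each full
// buffer with AddWithIDs and reuse it for the next batch.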

for i := 0; i < len(vecIndexes); i++ {
if isClosed(closeCh) {
freeReconstructedIndexes(vecIndexes)
return seg.ErrClosed
}

if len(vecIndexes[i].vecIds) > 0 {
neededReconsLen := len(vecIndexes[i].vecIds) * dims
recons = recons[:neededReconsLen]

recons, err = vecIndexes[i].index.ReconstructBatch(vecIndexes[i].vecIds, recons)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}

vecLen := len(vecIndexes[i].vecIds)
shift := 0
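// worked example with hypothetical numbers: vecLimit=100000, curVecs=60000
// already buffered, and an incoming segment with vecLen=150000. pass 1
// copies 40000 vectors to fill the buffer and flushes it (training on that
// first full batch); pass 2 copies and flushes another 100000; the last
// 10000 stay buffered for the next segment.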

for curVecs+vecLen > vecLimit {
indexData = append(indexData, recons[shift*dims:(shift+vecLimit-curVecs)*dims]...)
finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds[shift:(shift+vecLimit-curVecs)]...)

if !trained {
err = faissIndex.Train(indexData)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}
trained = true
}

err = faissIndex.AddWithIDs(indexData, finalVecIDs)
if err != nil {
freeReconstructedIndexes(vecIndexes)
return err
}

indexData = indexData[:0]
finalVecIDs = finalVecIDs[:0]
shift += vecLimit - curVecs
vecLen -= vecLimit - curVecs
curVecs = 0
}
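// vectors left over after the full flushes above stay buffered; they are
// topped up by the next segment or flushed once the loop ends.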

if vecLen != 0 {
indexData = append(indexData, recons[shift*dims:(shift+vecLen)*dims]...)
finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds[shift:shift+vecLen]...)
curVecs = len(finalVecIDs)
}
}
}

recons = nil
freeReconstructedIndexes(vecIndexes)
vecIndexes = nil
if curVecs > 0 {
if !trained {
err = faissIndex.Train(indexData)
if err != nil {
return err
}
}
err = faissIndex.AddWithIDs(indexData, finalVecIDs)
if err != nil {
return err
}
}
indexData = nil
finalVecIDs = nil
}

indexData = nil
finalVecIDs = nil
var mergedIndexBytes []byte
mergedIndexBytes, err = faiss.WriteIndexIntoBuffer(faissIndex)
if err != nil {
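For reference, the pattern this change introduces can be sketched in isolation. The snippet below is a minimal, hypothetical illustration rather than the PR's code: trainAndAddInBatches, its parameters, the placeholder data, and the "IVF64,SQ8" description are invented for the example, while IndexFactory, Train, AddWithIDs, and MetricL2 are the same go-faiss calls the file above already uses.

package main

import (
	faiss "github.com/blevesearch/go-faiss"
)

// trainAndAddInBatches trains idx on the first batch of at most `limit`
// vectors, then adds every vector in `limit`-sized batches, so the memory
// and CPU spent on training stay bounded regardless of the total count.
func trainAndAddInBatches(idx *faiss.IndexImpl, dims, limit int,
	vecs []float32, ids []int64) error {
	trained := false
	for start := 0; start < len(ids); start += limit {
		end := start + limit
		if end > len(ids) {
			end = len(ids)
		}
		batch := vecs[start*dims : end*dims]
		if !trained {
			// k-means runs once, on a sample no larger than limit,
			// instead of on the full data set.
			if err := idx.Train(batch); err != nil {
				return err
			}
			trained = true
		}
		if err := idx.AddWithIDs(batch, ids[start:end]); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	dims := 128
	// "IVF64,SQ8" is only an example; the real code derives the
	// description via determineIndexToUse.
	idx, err := faiss.IndexFactory(dims, "IVF64,SQ8", faiss.MetricL2)
	if err != nil {
		panic(err)
	}
	vecs := make([]float32, 1000*dims) // placeholder vectors
	ids := make([]int64, 1000)
	for i := range ids {
		ids[i] = int64(i)
	}
	if err := trainAndAddInBatches(idx, dims, 256, vecs, ids); err != nil {
		panic(err)
	}
}

Training on only the first full buffer is what keeps the merge's peak cost bounded; the trade-off is that centroids are computed from a sample rather than the full vector set, which is standard faiss practice for large collections.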