From 29fdfbea8fb8fe1e8a303993f0d749f779eda18a Mon Sep 17 00:00:00 2001 From: Likith B Date: Mon, 18 Nov 2024 19:34:59 +0530 Subject: [PATCH] MB-63831: Toy: Limited Training Size --- section_faiss_vector_index.go | 209 ++++++++++++++++++++++++---------- 1 file changed, 149 insertions(+), 60 deletions(-) diff --git a/section_faiss_vector_index.go b/section_faiss_vector_index.go index 1c9f91a..c923f28 100644 --- a/section_faiss_vector_index.go +++ b/section_faiss_vector_index.go @@ -281,7 +281,7 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase, // safe to assume that all the indexes are of the same config values, given // that they are extracted from the field mapping info. - var dims, metric int + var dims, metric, nvecs int var indexOptimizedFor string var validMerge bool @@ -308,6 +308,7 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase, } if len(vecIndexes[segI].vecIds) > 0 { indexReconsLen := len(vecIndexes[segI].vecIds) * index.D() + nvecs += len(vecIndexes[segI].vecIds) if indexReconsLen > reconsCap { reconsCap = indexReconsLen } @@ -328,59 +329,11 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase, return nil } - finalVecIDs := make([]int64, 0, finalVecIDCap) - // merging of indexes with reconstruction method. - // the indexes[i].vecIds has only the valid vecs of this vector - // index present in it, so we'd be reconstructing only those. - indexData := make([]float32, 0, indexDataCap) - // reusable buffer for reconstruction - recons := make([]float32, 0, reconsCap) - var err error - for i := 0; i < len(vecIndexes); i++ { - if isClosed(closeCh) { - freeReconstructedIndexes(vecIndexes) - return seg.ErrClosed - } - - // reconstruct the vectors only if present, it could be that - // some of the indexes had all of their vectors updated/deleted. 
- if len(vecIndexes[i].vecIds) > 0 { - neededReconsLen := len(vecIndexes[i].vecIds) * vecIndexes[i].index.D() - recons = recons[:neededReconsLen] - // todo: parallelize reconstruction - recons, err = vecIndexes[i].index.ReconstructBatch(vecIndexes[i].vecIds, recons) - if err != nil { - freeReconstructedIndexes(vecIndexes) - return err - } - indexData = append(indexData, recons...) - // Adding vector IDs in the same order as the vectors - finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds...) - } - } - - if len(indexData) == 0 { - // no valid vectors for this index, so we don't even have to - // record it in the section - freeReconstructedIndexes(vecIndexes) - return nil - } - recons = nil - - nvecs := len(finalVecIDs) - // index type to be created after merge based on the number of vectors // in indexData added into the index. nlist := determineCentroids(nvecs) indexDescription, indexClass := determineIndexToUse(nvecs, nlist, indexOptimizedFor) - // freeing the reconstructed indexes immediately - waiting till the end - // to do the same is not needed because the following operations don't need - // the reconstructed ones anymore and doing so will hold up memory which can - // be detrimental while creating indexes during introduction. - freeReconstructedIndexes(vecIndexes) - vecIndexes = nil - faissIndex, err := faiss.IndexFactory(dims, indexDescription, metric) if err != nil { return err @@ -398,24 +351,160 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase, nprobe := calculateNprobe(nlist, indexOptimizedFor) faissIndex.SetNProbe(nprobe) + } + + // train the index all at once if the number of vectors is below a certain + // threshold, otherwise use a sample set to train the index and incrementally + // add the vectors after + if nvecs < 100000 { + finalVecIDs := make([]int64, 0, finalVecIDCap) + // merging of indexes with reconstruction method. 
+ // the indexes[i].vecIds has only the valid vecs of this vector + // index present in it, so we'd be reconstructing only those. + indexData := make([]float32, 0, indexDataCap) + // reusable buffer for reconstruction + recons := make([]float32, 0, reconsCap) + var err error + for i := 0; i < len(vecIndexes); i++ { + if isClosed(closeCh) { + freeReconstructedIndexes(vecIndexes) + return seg.ErrClosed + } + + // reconstruct the vectors only if present, it could be that + // some of the indexes had all of their vectors updated/deleted. + if len(vecIndexes[i].vecIds) > 0 { + neededReconsLen := len(vecIndexes[i].vecIds) * vecIndexes[i].index.D() + recons = recons[:neededReconsLen] + // todo: parallelize reconstruction + recons, err = vecIndexes[i].index.ReconstructBatch(vecIndexes[i].vecIds, recons) + if err != nil { + freeReconstructedIndexes(vecIndexes) + return err + } + indexData = append(indexData, recons...) + // Adding vector IDs in the same order as the vectors + finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds...) + } + } + + if len(indexData) == 0 { + // no valid vectors for this index, so we don't even have to + // record it in the section + freeReconstructedIndexes(vecIndexes) + return nil + } + + recons = nil + // freeing the reconstructed indexes immediately - waiting till the end + // to do the same is not needed because the following operations don't need + // the reconstructed ones anymore and doing so will hold up memory which can + // be detrimental while creating indexes during introduction. + freeReconstructedIndexes(vecIndexes) + vecIndexes = nil - // train the vector index, essentially performs k-means clustering to partition - // the data space of indexData such that during the search time, we probe - // only a subset of vectors -> non-exhaustive search. could be a time - // consuming step when the indexData is large. 
- err = faissIndex.Train(indexData) + if indexClass == IndexTypeIVF { + // train the vector index, essentially performs k-means clustering to partition + // the data space of indexData such that during the search time, we probe + // only a subset of vectors -> non-exhaustive search. could be a time + // consuming step when the indexData is large. + err = faissIndex.Train(indexData) + if err != nil { + return err + } + } + err = faissIndex.AddWithIDs(indexData, finalVecIDs) if err != nil { return err } - } - err = faissIndex.AddWithIDs(indexData, finalVecIDs) - if err != nil { - return err + indexData = nil + finalVecIDs = nil + } else { + recons := make([]float32, 0, reconsCap) + curVecs := 0 + vecLimit := 100000 + if vecLimit < nlist*40 { + vecLimit = nlist * 40 + } + + finalVecIDs := make([]int64, 0, vecLimit) + indexData := make([]float32, 0, vecLimit*dims) + trained := false + var err error + + for i := 0; i < len(vecIndexes); i++ { + if isClosed(closeCh) { + freeReconstructedIndexes(vecIndexes) + return seg.ErrClosed + } + + if len(vecIndexes[i].vecIds) > 0 { + neededReconsLen := len(vecIndexes[i].vecIds) * dims + recons = recons[:neededReconsLen] + + recons, err = vecIndexes[i].index.ReconstructBatch(vecIndexes[i].vecIds, recons) + if err != nil { + freeReconstructedIndexes(vecIndexes) + return err + } + + vecLen := len(vecIndexes[i].vecIds) + shift := 0 + + for curVecs+vecLen > vecLimit { + indexData = append(indexData, recons[shift*dims:(shift+vecLimit-curVecs)*dims]...) + finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds[shift:(shift+vecLimit-curVecs)]...) 
+ + if !trained { + err = faissIndex.Train(indexData) + if err != nil { + freeReconstructedIndexes(vecIndexes) + return err + } + trained = true + } + + err = faissIndex.AddWithIDs(indexData, finalVecIDs) + if err != nil { + freeReconstructedIndexes(vecIndexes) + return err + } + + indexData = indexData[:0] + finalVecIDs = finalVecIDs[:0] + shift += vecLimit - curVecs + vecLen -= vecLimit - curVecs + curVecs = 0 + } + + if vecLen != 0 { + indexData = append(indexData, recons[shift*dims:(shift+vecLen)*dims]...) + finalVecIDs = append(finalVecIDs, vecIndexes[i].vecIds[shift:shift+vecLen]...) + curVecs = len(finalVecIDs) + } + } + } + + recons = nil + freeReconstructedIndexes(vecIndexes) + vecIndexes = nil + if curVecs > 0 { + if !trained { + err = faissIndex.Train(indexData) + if err != nil { + return err + } + } + err = faissIndex.AddWithIDs(indexData, finalVecIDs) + if err != nil { + return err + } + } + indexData = nil + finalVecIDs = nil } - indexData = nil - finalVecIDs = nil var mergedIndexBytes []byte mergedIndexBytes, err = faiss.WriteIndexIntoBuffer(faissIndex) if err != nil {