Skip to content

Commit

Permalink
Add IndexIVFFlat example
Browse files Browse the repository at this point in the history
  • Loading branch information
p-ouellette committed Feb 14, 2021
1 parent 5848392 commit a5d70ac
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
6 changes: 6 additions & 0 deletions _example/flat/flat.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
// Usage example for IndexFlat.
// Based on tutorial/cpp/1-Flat.cpp from the Faiss distribution.
// See https://github.com/facebookresearch/faiss/wiki/Getting-started for more
// information.
package main

import (
Expand Down Expand Up @@ -34,6 +38,8 @@ func main() {
if err != nil {
log.Fatal(err)
}
defer index.Delete()

fmt.Println("IsTrained() =", index.IsTrained())
index.Add(xb)
fmt.Println("Ntotal() =", index.Ntotal())
Expand Down
90 changes: 90 additions & 0 deletions _example/ivfflat/ivfflat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Usage example for IndexIVFFlat.
// Based on tutorial/cpp/2-IVFFlat.cpp from the Faiss distribution.
// See https://github.com/facebookresearch/faiss/wiki/Faster-search for more
// information.
package main

import (
"fmt"
"log"
"math/rand"

"github.com/DataIntelligenceCrew/go-faiss"
)

func main() {
d := 64 // dimension
nb := 100000 // database size
nq := 10000 // number of queries

xb := make([]float32, d*nb)
xq := make([]float32, d*nq)

for i := 0; i < nb; i++ {
for j := 0; j < d; j++ {
xb[i*d+j] = rand.Float32()
}
xb[i*d] += float32(i) / 1000
}

for i := 0; i < nq; i++ {
for j := 0; j < d; j++ {
xq[i*d+j] = rand.Float32()
}
xq[i*d] += float32(i) / 1000
}

index, err := faiss.IndexFactory(d, "IVF100,Flat", faiss.MetricL2)
if err != nil {
log.Fatal(err)
}
defer index.Delete()

fmt.Println("IsTrained() =", index.IsTrained())
index.Train(xb)
fmt.Println("IsTrained() =", index.IsTrained())
index.Add(xb)
fmt.Println("Ntotal() =", index.Ntotal())

k := int64(4)

// search xq

_, ids, err := index.Search(xq, k)
if err != nil {
log.Fatal(err)
}

fmt.Println("ids (last 5 results)=")
for i := int64(nq) - 5; i < int64(nq); i++ {
for j := int64(0); j < k; j++ {
fmt.Printf("%5d ", ids[i*k+j])
}
fmt.Println()
}

// retry with nprobe=10 (default is 1)

ps, err := faiss.NewParameterSpace()
if err != nil {
log.Fatal(err)
}
defer ps.Delete()

if err := ps.SetIndexParameter(index, "nprobe", 10); err != nil {
log.Fatal(err)
}

_, ids, err = index.Search(xq, k)
if err != nil {
log.Fatal(err)
}

fmt.Println("ids (last 5 results)=")
for i := int64(nq) - 5; i < int64(nq); i++ {
for j := int64(0); j < k; j++ {
fmt.Printf("%5d ", ids[i*k+j])
}
fmt.Println()
}
}

0 comments on commit a5d70ac

Please sign in to comment.