Skip to content

Commit

Permalink
update pinecone local test suite to test across a large dimension of …
Browse files Browse the repository at this point in the history
…index and namespace combinations, make sure we're using sparse vectors and attaching metadata, cover more deletion cases when the suite cleans up
  • Loading branch information
austin-denoble committed Sep 19, 2024
1 parent 165e7f4 commit be98dd7
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 49 deletions.
6 changes: 3 additions & 3 deletions pinecone/index_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (ts *IntegrationTests) TestDeleteVectorsById() {
assert.NoError(ts.T(), err)
ts.vectorIds = []string{}

vectors := GenerateVectors(5, ts.dimension, true)
vectors := GenerateVectors(5, ts.dimension, true, nil)

_, err = ts.idxConn.UpsertVectors(ctx, vectors)
if err != nil {
Expand Down Expand Up @@ -96,7 +96,7 @@ func (ts *IntegrationTests) TestDeleteVectorsByFilter() {
}
ts.vectorIds = []string{}

vectors := GenerateVectors(5, ts.dimension, true)
vectors := GenerateVectors(5, ts.dimension, true, nil)

_, err = ts.idxConn.UpsertVectors(ctx, vectors)
if err != nil {
Expand All @@ -117,7 +117,7 @@ func (ts *IntegrationTests) TestDeleteAllVectorsInNamespace() {
assert.NoError(ts.T(), err)
ts.vectorIds = []string{}

vectors := GenerateVectors(5, ts.dimension, true)
vectors := GenerateVectors(5, ts.dimension, true, nil)

_, err = ts.idxConn.UpsertVectors(ctx, vectors)
if err != nil {
Expand Down
119 changes: 82 additions & 37 deletions pinecone/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/types/known/structpb"
)

type LocalIntegrationTests struct {
Expand All @@ -22,44 +23,55 @@ type LocalIntegrationTests struct {
host string
dimension int32
indexType string
namespace string
metadata *Metadata
vectorIds []string
idxConn *IndexConnection
idxConns []*IndexConnection
}

func (ts *LocalIntegrationTests) SetupSuite() {
ctx := context.Background()

// Deterministically create vectors
vectors := GenerateVectors(100, ts.dimension, false)
vectors := GenerateVectors(100, ts.dimension, true, ts.metadata)

// Upsert vectors
upsertedVectors, err := ts.idxConn.UpsertVectors(ctx, vectors)
require.NoError(ts.T(), err)
fmt.Printf("Upserted vectors: %v into host: %s\n", upsertedVectors, ts.host)

// Add vector ids to the suite
// Get vector ids for the suite
vectorIds := make([]string, len(vectors))
for i, v := range vectors {
vectorIds[i] = v.Id
}

// Upsert vectors into each index connection
for _, idxConn := range ts.idxConns {
upsertedVectors, err := idxConn.UpsertVectors(ctx, vectors)
require.NoError(ts.T(), err)
fmt.Printf("Upserted vectors: %v into host: %s in namespace: %s \n", upsertedVectors, ts.host, idxConn.Namespace)
}

ts.vectorIds = append(ts.vectorIds, vectorIds...)
}

func (ts *LocalIntegrationTests) TearDownSuite() {
// test deleting vectors as a part of cleanup
err := ts.idxConn.DeleteVectorsById(context.Background(), ts.vectorIds)
require.NoError(ts.T(), err)
// test deleting vectors as a part of cleanup for each index connection
for _, idxConn := range ts.idxConns {
// Delete a slice of vectors by id
err := idxConn.DeleteVectorsById(context.Background(), ts.vectorIds[10:20])
require.NoError(ts.T(), err)

description, err := ts.idxConn.DescribeIndexStats(context.Background())
require.NoError(ts.T(), err)
// Delete vectors by filter

assert.NotNil(ts.T(), description, "Index description should not be nil")
assert.Equal(ts.T(), uint32(0), description.TotalVectorCount, "Total vector count should be 0 after deleting")
// Delete all remaining vectors

description, err := idxConn.DescribeIndexStats(context.Background())
require.NoError(ts.T(), err)
assert.NotNil(ts.T(), description, "Index description should not be nil")
assert.Equal(ts.T(), uint32(0), description.TotalVectorCount, "Total vector count should be 0 after deleting")
}
}

// This is the entry point for all local integration tests
// This test function is picked up by go test and triggers the suite runs when
// the
// the build tag localServer is set
func TestRunLocalIntegrationSuite(t *testing.T) {
fmt.Println("Running local integration tests")
RunLocalSuite(t)
Expand All @@ -79,30 +91,55 @@ func RunLocalSuite(t *testing.T) {
parsedDimension, err := strconv.ParseInt(dimension, 10, 32)
require.NoError(t, err)

namespace := "test-namespace"
metadata := &structpb.Struct{
Fields: map[string]*structpb.Value{
"genre": {Kind: &structpb.Value_StringValue{StringValue: "classical"}},
},
}

client, err := NewClientBase(NewClientBaseParams{})
require.NotNil(t, client, "Client should not be nil after creation")
require.NoError(t, err)

// Create index connections for pod and serverless indexes with both default namespace
// and a custom namespace
var podIdxConns []*IndexConnection
idxConnPod, err := client.Index(NewIndexConnParams{Host: localHostPod})
require.NoError(t, err)
podIdxConns = append(podIdxConns, idxConnPod)

idxConnPodNamespace, err := client.Index(NewIndexConnParams{Host: localHostPod, Namespace: namespace})
require.NoError(t, err)
podIdxConns = append(podIdxConns, idxConnPodNamespace)

var serverlessIdxConns []*IndexConnection
idxConnServerless, err := client.Index(NewIndexConnParams{Host: localHostServerless},
grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
serverlessIdxConns = append(serverlessIdxConns, idxConnServerless)

idxConnServerless, err = client.Index(NewIndexConnParams{Host: localHostServerless, Namespace: namespace})
require.NoError(t, err)
serverlessIdxConns = append(serverlessIdxConns, idxConnServerless)

localHostPodSuite := &LocalIntegrationTests{
client: client,
idxConn: idxConnPod,
idxConns: podIdxConns,
indexType: "pods",
host: localHostPod,
namespace: namespace,
metadata: metadata,
dimension: int32(parsedDimension),
}

localHostSuiteServerless := &LocalIntegrationTests{
client: client,
idxConn: idxConnServerless,
idxConns: serverlessIdxConns,
indexType: "serverless",
host: localHostServerless,
namespace: namespace,
metadata: metadata,
dimension: int32(parsedDimension),
}

Expand All @@ -113,42 +150,50 @@ func RunLocalSuite(t *testing.T) {
func (ts *LocalIntegrationTests) TestFetchVectors() {
fetchVectorId := ts.vectorIds[0]

fetchVectorsResponse, err := ts.idxConn.FetchVectors(context.Background(), []string{fetchVectorId})
require.NoError(ts.T(), err)
for _, idxConn := range ts.idxConns {
fetchVectorsResponse, err := idxConn.FetchVectors(context.Background(), []string{fetchVectorId})
require.NoError(ts.T(), err)

assert.NotNil(ts.T(), fetchVectorsResponse, "Fetch vectors response should not be nil")
assert.Equal(ts.T(), 1, len(fetchVectorsResponse.Vectors), "Fetch vectors response should have 1 vector")
assert.Equal(ts.T(), fetchVectorId, fetchVectorsResponse.Vectors[fetchVectorId].Id, "Fetched vector id should match")
assert.NotNil(ts.T(), fetchVectorsResponse, "Fetch vectors response should not be nil")
assert.Equal(ts.T(), 1, len(fetchVectorsResponse.Vectors), "Fetch vectors response should have 1 vector")
assert.Equal(ts.T(), fetchVectorId, fetchVectorsResponse.Vectors[fetchVectorId].Id, "Fetched vector id should match")
}
}

func (ts *LocalIntegrationTests) TestQueryVectors() {
queryVectorId := ts.vectorIds[0]
topK := 10

queryVectorsByIdResponse, err := ts.idxConn.QueryByVectorId(context.Background(), &QueryByVectorIdRequest{VectorId: queryVectorId, TopK: uint32(topK)})
require.NoError(ts.T(), err)
for _, idxConn := range ts.idxConns {
queryVectorsByIdResponse, err := idxConn.QueryByVectorId(context.Background(), &QueryByVectorIdRequest{VectorId: queryVectorId, TopK: uint32(topK)})
require.NoError(ts.T(), err)

assert.NotNil(ts.T(), queryVectorsByIdResponse, "Query results should not be nil")
assert.Equal(ts.T(), topK, len(queryVectorsByIdResponse.Matches), "Query results should have 10 matches")
assert.Equal(ts.T(), queryVectorId, queryVectorsByIdResponse.Matches[0].Vector.Id, "Top query result vector id should match queryVectorId")
assert.NotNil(ts.T(), queryVectorsByIdResponse, "Query results should not be nil")
assert.Equal(ts.T(), topK, len(queryVectorsByIdResponse.Matches), "Query results should have 10 matches")
assert.Equal(ts.T(), queryVectorId, queryVectorsByIdResponse.Matches[0].Vector.Id, "Top query result vector id should match queryVectorId")
}
}

func (ts *LocalIntegrationTests) TestUpdateVectors() {
updateVectorId := ts.vectorIds[0]
newValues := generateVectorValues(ts.dimension)

err := ts.idxConn.UpdateVector(context.Background(), &UpdateVectorRequest{Id: updateVectorId, Values: newValues})
require.NoError(ts.T(), err)
for _, idxConn := range ts.idxConns {
err := idxConn.UpdateVector(context.Background(), &UpdateVectorRequest{Id: updateVectorId, Values: newValues})
require.NoError(ts.T(), err)

fetchVectorsResponse, err := ts.idxConn.FetchVectors(context.Background(), []string{updateVectorId})
require.NoError(ts.T(), err)
assert.Equal(ts.T(), newValues, fetchVectorsResponse.Vectors[updateVectorId].Values, "Updated vector values should match")
fetchVectorsResponse, err := idxConn.FetchVectors(context.Background(), []string{updateVectorId})
require.NoError(ts.T(), err)
assert.Equal(ts.T(), newValues, fetchVectorsResponse.Vectors[updateVectorId].Values, "Updated vector values should match")
}
}

func (ts *LocalIntegrationTests) TestDescribeIndexStats() {
description, err := ts.idxConn.DescribeIndexStats(context.Background())
require.NoError(ts.T(), err)
for _, idxConn := range ts.idxConns {
description, err := idxConn.DescribeIndexStats(context.Background())
require.NoError(ts.T(), err)

assert.NotNil(ts.T(), description, "Index description should not be nil")
assert.Equal(ts.T(), description.TotalVectorCount, uint32(len(ts.vectorIds)), "Index host should match")
assert.NotNil(ts.T(), description, "Index description should not be nil")
assert.Equal(ts.T(), description.TotalVectorCount, uint32(len(ts.vectorIds)), "Index host should match")
}
}
13 changes: 4 additions & 9 deletions pinecone/test_suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"math/rand"
"time"

"google.golang.org/protobuf/types/known/structpb"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -47,7 +45,7 @@ func (ts *IntegrationTests) SetupSuite() {
ts.idxConn = idxConn

// Deterministically create vectors
vectors := GenerateVectors(10, ts.dimension, false)
vectors := GenerateVectors(10, ts.dimension, false, nil)

// Add vector ids to the suite
vectorIds := make([]string, len(vectors))
Expand Down Expand Up @@ -157,7 +155,7 @@ func WaitUntilIndexReady(ts *IntegrationTests, ctx context.Context) (bool, error
}
}

func GenerateVectors(numOfVectors int, dimension int32, isSparse bool) []*Vector {
func GenerateVectors(numOfVectors int, dimension int32, isSparse bool, metadata *Metadata) []*Vector {
vectors := make([]*Vector, numOfVectors)

for i := 0; i < int(numOfVectors); i++ {
Expand All @@ -176,12 +174,9 @@ func GenerateVectors(numOfVectors int, dimension int32, isSparse bool) []*Vector
vectors[i].SparseValues = &sparseValues
}

metadata := &structpb.Struct{
Fields: map[string]*structpb.Value{
"genre": {Kind: &structpb.Value_StringValue{StringValue: "classical"}},
},
if metadata != nil {
vectors[i].Metadata = metadata
}
vectors[i].Metadata = metadata
}

return vectors
Expand Down

0 comments on commit be98dd7

Please sign in to comment.