Skip to content

Commit

Permalink
remove Context from IndexConnection
Browse files Browse the repository at this point in the history
  • Loading branch information
haruska committed Mar 1, 2024
1 parent 948e9b7 commit 465e447
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 39 deletions.
56 changes: 25 additions & 31 deletions pinecone/index_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,18 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"log"
"time"
)

type IndexConnection struct {
host string
apiKey string
dataClient *data.VectorServiceClient
ctx *context.Context
ctxCancel context.CancelFunc
grpcConn *grpc.ClientConn
}

func newIndexConnection(apiKey string, host string) (*IndexConnection, error) {
config := &tls.Config{}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

ctx = metadata.AppendToOutgoingContext(ctx, "api-key", apiKey)
target := fmt.Sprintf("%s:443", host)

conn, err := grpc.DialContext(
ctx,
conn, err := grpc.Dial(
target,
grpc.WithTransportCredentials(credentials.NewTLS(config)),
grpc.WithAuthority(target),
Expand All @@ -43,12 +34,11 @@ func newIndexConnection(apiKey string, host string) (*IndexConnection, error) {

dataClient := data.NewVectorServiceClient(conn)

idx := IndexConnection{host: host, dataClient: &dataClient, ctx: &ctx, ctxCancel: cancel, grpcConn: conn}
idx := IndexConnection{apiKey: apiKey, dataClient: &dataClient, grpcConn: conn}
return &idx, nil
}

func (idx *IndexConnection) Close() error {
idx.ctxCancel()
err := idx.grpcConn.Close()
return err
}
Expand All @@ -58,7 +48,7 @@ type UpsertVectorsRequest struct {
Namespace string
}

func (idx *IndexConnection) UpsertVectors(in *UpsertVectorsRequest) (uint32, error) {
func (idx *IndexConnection) UpsertVectors(ctx *context.Context, in *UpsertVectorsRequest) (uint32, error) {
vectors := make([]*data.Vector, len(in.Vectors))
for i, v := range in.Vectors {
vectors[i] = vecToGrpc(v)
Expand All @@ -69,7 +59,7 @@ func (idx *IndexConnection) UpsertVectors(in *UpsertVectorsRequest) (uint32, err
Namespace: in.Namespace,
}

res, err := (*idx.dataClient).Upsert(*idx.ctx, req)
res, err := (*idx.dataClient).Upsert(idx.akCtx(*ctx), req)
if err != nil {
return 0, err
}
Expand All @@ -87,13 +77,13 @@ type FetchVectorsResponse struct {
Usage *Usage
}

func (idx *IndexConnection) FetchVectors(in *FetchVectorsRequest) (*FetchVectorsResponse, error) {
func (idx *IndexConnection) FetchVectors(ctx *context.Context, in *FetchVectorsRequest) (*FetchVectorsResponse, error) {
req := &data.FetchRequest{
Ids: in.Ids,
Namespace: in.Namespace,
}

res, err := (*idx.dataClient).Fetch(*idx.ctx, req)
res, err := (*idx.dataClient).Fetch(idx.akCtx(*ctx), req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -124,14 +114,14 @@ type ListVectorsResponse struct {
NextPaginationToken *string
}

func (idx *IndexConnection) ListVectors(in *ListVectorsRequest) (*ListVectorsResponse, error) {
func (idx *IndexConnection) ListVectors(ctx *context.Context, in *ListVectorsRequest) (*ListVectorsResponse, error) {
req := &data.ListRequest{
Prefix: in.Prefix,
Limit: in.Limit,
PaginationToken: in.PaginationToken,
Namespace: in.Namespace,
}
res, err := (*idx.dataClient).List(*idx.ctx, req)
res, err := (*idx.dataClient).List(idx.akCtx(*ctx), req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -165,7 +155,7 @@ type QueryVectorsResponse struct {
Usage *Usage
}

func (idx *IndexConnection) QueryByVector(in *QueryByVectorRequest) (*QueryVectorsResponse, error) {
func (idx *IndexConnection) QueryByVector(ctx *context.Context, in *QueryByVectorRequest) (*QueryVectorsResponse, error) {
req := &data.QueryRequest{
Namespace: in.Namespace,
TopK: in.TopK,
Expand All @@ -176,7 +166,7 @@ func (idx *IndexConnection) QueryByVector(in *QueryByVectorRequest) (*QueryVecto
SparseVector: sparseValToGrpc(in.SparseValues),
}

return idx.query(req)
return idx.query(ctx, req)
}

type QueryByIdRequest struct {
Expand All @@ -189,7 +179,7 @@ type QueryByIdRequest struct {
SparseValues *SparseValues
}

func (idx *IndexConnection) QueryById(in *QueryByIdRequest) (*QueryVectorsResponse, error) {
func (idx *IndexConnection) QueryById(ctx *context.Context, in *QueryByIdRequest) (*QueryVectorsResponse, error) {
req := &data.QueryRequest{
Id: in.Id,
Namespace: in.Namespace,
Expand All @@ -200,7 +190,7 @@ func (idx *IndexConnection) QueryById(in *QueryByIdRequest) (*QueryVectorsRespon
SparseVector: sparseValToGrpc(in.SparseValues),
}

return idx.query(req)
return idx.query(ctx, req)
}

type DeleteVectorsRequest struct {
Expand All @@ -210,15 +200,15 @@ type DeleteVectorsRequest struct {
DeleteAll bool
}

func (idx *IndexConnection) DeleteVectors(in *DeleteVectorsRequest) error {
func (idx *IndexConnection) DeleteVectors(ctx *context.Context, in *DeleteVectorsRequest) error {
req := data.DeleteRequest{
Ids: in.Ids,
DeleteAll: in.DeleteAll,
Namespace: in.Namespace,
Filter: in.Filter,
}

_, err := (*idx.dataClient).Delete(*idx.ctx, &req)
_, err := (*idx.dataClient).Delete(idx.akCtx(*ctx), &req)
return err
}

Expand All @@ -230,7 +220,7 @@ type UpdateVectorRequest struct {
Namespace string
}

func (idx *IndexConnection) UpdateVector(in *UpdateVectorRequest) error {
func (idx *IndexConnection) UpdateVector(ctx *context.Context, in *UpdateVectorRequest) error {
req := &data.UpdateRequest{
Id: in.Id,
Values: in.Values,
Expand All @@ -239,7 +229,7 @@ func (idx *IndexConnection) UpdateVector(in *UpdateVectorRequest) error {
Namespace: in.Namespace,
}

_, err := (*idx.dataClient).Update(*idx.ctx, req)
_, err := (*idx.dataClient).Update(idx.akCtx(*ctx), req)
return err
}

Expand All @@ -254,11 +244,11 @@ type DescribeIndexStatsResponse struct {
Namespaces map[string]*NamespaceSummary
}

func (idx *IndexConnection) DescribeIndexStats(in *DescribeIndexStatsRequest) (*DescribeIndexStatsResponse, error) {
func (idx *IndexConnection) DescribeIndexStats(ctx *context.Context, in *DescribeIndexStatsRequest) (*DescribeIndexStatsResponse, error) {
req := &data.DescribeIndexStatsRequest{
Filter: in.Filter,
}
res, err := (*idx.dataClient).DescribeIndexStats(*idx.ctx, req)
res, err := (*idx.dataClient).DescribeIndexStats(idx.akCtx(*ctx), req)
if err != nil {
return nil, err
}
Expand All @@ -278,8 +268,8 @@ func (idx *IndexConnection) DescribeIndexStats(in *DescribeIndexStatsRequest) (*
}, nil
}

func (idx *IndexConnection) query(req *data.QueryRequest) (*QueryVectorsResponse, error) {
res, err := (*idx.dataClient).Query(*idx.ctx, req)
func (idx *IndexConnection) query(ctx *context.Context, req *data.QueryRequest) (*QueryVectorsResponse, error) {
res, err := (*idx.dataClient).Query(idx.akCtx(*ctx), req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -362,3 +352,7 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues {
Values: sv.Values,
}
}

func (idx *IndexConnection) akCtx(ctx context.Context) context.Context {
return metadata.AppendToOutgoingContext(ctx, "api-key", idx.apiKey)
}
26 changes: 18 additions & 8 deletions pinecone/index_connection_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pinecone

import (
"context"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -61,7 +62,8 @@ func (ts *IndexConnectionTests) TestFetchVectors() {
Namespace: ts.namespace,
}

res, err := ts.idxConn.FetchVectors(req)
ctx := context.Background()
res, err := ts.idxConn.FetchVectors(&ctx, req)
assert.NoError(ts.T(), err)
assert.NotNil(ts.T(), res)
}
Expand All @@ -73,7 +75,8 @@ func (ts *IndexConnectionTests) TestQueryByVector() {
TopK: 5,
}

res, err := ts.idxConn.QueryByVector(req)
ctx := context.Background()
res, err := ts.idxConn.QueryByVector(&ctx, req)
assert.NoError(ts.T(), err)
assert.NotNil(ts.T(), res)
}
Expand All @@ -85,7 +88,8 @@ func (ts *IndexConnectionTests) TestQueryById() {
TopK: 5,
}

res, err := ts.idxConn.QueryById(req)
ctx := context.Background()
res, err := ts.idxConn.QueryById(&ctx, req)
assert.NoError(ts.T(), err)
assert.NotNil(ts.T(), res)
}
Expand All @@ -96,15 +100,17 @@ func (ts *IndexConnectionTests) TestDeleteVectors() {
Namespace: ts.namespace,
}

err := ts.idxConn.DeleteVectors(req)
ctx := context.Background()
err := ts.idxConn.DeleteVectors(&ctx, req)
assert.NoError(ts.T(), err)

ts.loadData() //reload deleted data
}

func (ts *IndexConnectionTests) TestDescribeIndexStats() {
req := &DescribeIndexStatsRequest{}
res, err := ts.idxConn.DescribeIndexStats(req)
ctx := context.Background()
res, err := ts.idxConn.DescribeIndexStats(&ctx, req)
assert.NoError(ts.T(), err)
assert.NotNil(ts.T(), res)
}
Expand All @@ -115,7 +121,8 @@ func (ts *IndexConnectionTests) TestListVectors() {
Namespace: ts.namespace,
}

res, err := ts.idxConn.ListVectors(req)
ctx := context.Background()
res, err := ts.idxConn.ListVectors(&ctx, req)
assert.NoError(ts.T(), err)
assert.NotNil(ts.T(), res)
}
Expand All @@ -138,12 +145,15 @@ func (ts *IndexConnectionTests) loadData() {
Vectors: vectors,
Namespace: ts.namespace,
}
_, err := ts.idxConn.UpsertVectors(req)

ctx := context.Background()
_, err := ts.idxConn.UpsertVectors(&ctx, req)
assert.NoError(ts.T(), err)
}

func (ts *IndexConnectionTests) truncateData() {
err := ts.idxConn.DeleteVectors(&DeleteVectorsRequest{
ctx := context.Background()
err := ts.idxConn.DeleteVectors(&ctx, &DeleteVectorsRequest{
DeleteAll: true,
Namespace: ts.namespace,
})
Expand Down

0 comments on commit 465e447

Please sign in to comment.