Skip to content

Commit

Permalink
standardize the import names for different modules
Browse files Browse the repository at this point in the history
  • Loading branch information
austin-denoble committed Oct 15, 2024
1 parent 27de308 commit fc1911b
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 83 deletions.
12 changes: 6 additions & 6 deletions pinecone/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

"github.com/pinecone-io/go-pinecone/internal/gen"
"github.com/pinecone-io/go-pinecone/internal/gen/db_control"
db_data "github.com/pinecone-io/go-pinecone/internal/gen/db_data/rest"
db_data_rest "github.com/pinecone-io/go-pinecone/internal/gen/db_data/rest"
"github.com/pinecone-io/go-pinecone/internal/gen/inference"
"github.com/pinecone-io/go-pinecone/internal/provider"
"github.com/pinecone-io/go-pinecone/internal/useragent"
Expand Down Expand Up @@ -313,7 +313,7 @@ func (c *Client) Index(in NewIndexConnParams, dialOpts ...grpc.DialOption) (*Ind
}

dbDataOptions := buildDataClientBaseOptions(*c.baseParams)
dbDataClient, err := db_data.NewClient(ensureHostHasHttps(in.Host), dbDataOptions...)
dbDataClient, err := db_data_rest.NewClient(ensureHostHasHttps(in.Host), dbDataOptions...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1550,17 +1550,17 @@ func buildInferenceBaseOptions(in NewClientBaseParams) []inference.ClientOption
return clientOptions
}

func buildDataClientBaseOptions(in NewClientBaseParams) []db_data.ClientOption {
clientOptions := []db_data.ClientOption{}
func buildDataClientBaseOptions(in NewClientBaseParams) []db_data_rest.ClientOption {
clientOptions := []db_data_rest.ClientOption{}
headerProviders := buildSharedProviderHeaders(in)

for _, provider := range headerProviders {
clientOptions = append(clientOptions, db_data.WithRequestEditorFn(provider.Intercept))
clientOptions = append(clientOptions, db_data_rest.WithRequestEditorFn(provider.Intercept))
}

// apply custom http client if provided
if in.RestClient != nil {
clientOptions = append(clientOptions, db_data.WithHTTPClient(in.RestClient))
clientOptions = append(clientOptions, db_data_rest.WithHTTPClient(in.RestClient))
}

return clientOptions
Expand Down
84 changes: 42 additions & 42 deletions pinecone/index_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
"net/url"
"strings"

dbDataGrpc "github.com/pinecone-io/go-pinecone/internal/gen/db_data/grpc"
dbDataRest "github.com/pinecone-io/go-pinecone/internal/gen/db_data/rest"
db_data_grpc "github.com/pinecone-io/go-pinecone/internal/gen/db_data/grpc"
db_data_rest "github.com/pinecone-io/go-pinecone/internal/gen/db_data/rest"
"github.com/pinecone-io/go-pinecone/internal/useragent"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand All @@ -30,8 +30,8 @@ import (
type IndexConnection struct {
Namespace string
additionalMetadata map[string]string
restClient *dbDataRest.Client
grpcClient *dbDataGrpc.VectorServiceClient
restClient *db_data_rest.Client
grpcClient *db_data_grpc.VectorServiceClient
grpcConn *grpc.ClientConn
}

Expand All @@ -40,7 +40,7 @@ type newIndexParameters struct {
namespace string
sourceTag string
additionalMetadata map[string]string
dbDataClient *dbDataRest.Client
dbDataClient *db_data_rest.Client
}

func newIndexConnection(in newIndexParameters, dialOpts ...grpc.DialOption) (*IndexConnection, error) {
Expand Down Expand Up @@ -71,7 +71,7 @@ func newIndexConnection(in newIndexParameters, dialOpts ...grpc.DialOption) (*In
return nil, err
}

dataClient := dbDataGrpc.NewVectorServiceClient(conn)
dataClient := db_data_grpc.NewVectorServiceClient(conn)

idx := IndexConnection{
Namespace: in.namespace,
Expand Down Expand Up @@ -192,12 +192,12 @@ func (idx *IndexConnection) Close() error {
// log.Fatalf("Successfully upserted %d vector(s)!\n", count)
// }
func (idx *IndexConnection) UpsertVectors(ctx context.Context, in []*Vector) (uint32, error) {
vectors := make([]*dbDataGrpc.Vector, len(in))
vectors := make([]*db_data_grpc.Vector, len(in))
for i, v := range in {
vectors[i] = vecToGrpc(v)
}

req := &dbDataGrpc.UpsertRequest{
req := &db_data_grpc.UpsertRequest{
Vectors: vectors,
Namespace: idx.Namespace,
}
Expand Down Expand Up @@ -270,7 +270,7 @@ type FetchVectorsResponse struct {
// fmt.Println("No vectors found")
// }
func (idx *IndexConnection) FetchVectors(ctx context.Context, ids []string) (*FetchVectorsResponse, error) {
req := &dbDataGrpc.FetchRequest{
req := &db_data_grpc.FetchRequest{
Ids: ids,
Namespace: idx.Namespace,
}
Expand Down Expand Up @@ -378,7 +378,7 @@ type ListVectorsResponse struct {
// fmt.Printf("Found %d vector(s)\n", len(res.VectorIds))
// }
func (idx *IndexConnection) ListVectors(ctx context.Context, in *ListVectorsRequest) (*ListVectorsResponse, error) {
req := &dbDataGrpc.ListRequest{
req := &db_data_grpc.ListRequest{
Prefix: in.Prefix,
Limit: in.Limit,
PaginationToken: in.PaginationToken,
Expand Down Expand Up @@ -508,7 +508,7 @@ type QueryVectorsResponse struct {
// }
// }
func (idx *IndexConnection) QueryByVectorValues(ctx context.Context, in *QueryByVectorValuesRequest) (*QueryVectorsResponse, error) {
req := &dbDataGrpc.QueryRequest{
req := &db_data_grpc.QueryRequest{
Namespace: idx.Namespace,
TopK: in.TopK,
Filter: in.MetadataFilter,
Expand Down Expand Up @@ -598,7 +598,7 @@ type QueryByVectorIdRequest struct {
// }
// }
func (idx *IndexConnection) QueryByVectorId(ctx context.Context, in *QueryByVectorIdRequest) (*QueryVectorsResponse, error) {
req := &dbDataGrpc.QueryRequest{
req := &db_data_grpc.QueryRequest{
Id: in.VectorId,
Namespace: idx.Namespace,
TopK: in.TopK,
Expand Down Expand Up @@ -659,7 +659,7 @@ func (idx *IndexConnection) QueryByVectorId(ctx context.Context, in *QueryByVect
// log.Fatalf("Failed to delete vector with ID: %s. Error: %s\n", vectorId, err)
// }
func (idx *IndexConnection) DeleteVectorsById(ctx context.Context, ids []string) error {
req := dbDataGrpc.DeleteRequest{
req := db_data_grpc.DeleteRequest{
Ids: ids,
Namespace: idx.Namespace,
}
Expand Down Expand Up @@ -723,7 +723,7 @@ func (idx *IndexConnection) DeleteVectorsById(ctx context.Context, ids []string)
// log.Fatalf("Failed to delete vector(s) with filter: %+v. Error: %s\n", filter, err)
// }
func (idx *IndexConnection) DeleteVectorsByFilter(ctx context.Context, metadataFilter *MetadataFilter) error {
req := dbDataGrpc.DeleteRequest{
req := db_data_grpc.DeleteRequest{
Filter: metadataFilter,
Namespace: idx.Namespace,
}
Expand Down Expand Up @@ -775,7 +775,7 @@ func (idx *IndexConnection) DeleteVectorsByFilter(ctx context.Context, metadataF
// log.Fatalf("Failed to delete vectors in namespace: \"%s\". Error: %s", idxConnection.Namespace, err)
// }
func (idx *IndexConnection) DeleteAllVectorsInNamespace(ctx context.Context) error {
req := dbDataGrpc.DeleteRequest{
req := db_data_grpc.DeleteRequest{
Namespace: idx.Namespace,
DeleteAll: true,
}
Expand Down Expand Up @@ -849,7 +849,7 @@ func (idx *IndexConnection) UpdateVector(ctx context.Context, in *UpdateVectorRe
return fmt.Errorf("a vector ID plus at least one of Values, SparseValues, or Metadata must be provided to update a vector")
}

req := &dbDataGrpc.UpdateRequest{
req := &db_data_grpc.UpdateRequest{
Id: in.Id,
Values: in.Values,
SparseValues: sparseValToGrpc(in.SparseValues),
Expand Down Expand Up @@ -980,7 +980,7 @@ func (idx *IndexConnection) DescribeIndexStats(ctx context.Context) (*DescribeIn
// }
// }
func (idx *IndexConnection) DescribeIndexStatsFiltered(ctx context.Context, metadataFilter *MetadataFilter) (*DescribeIndexStatsResponse, error) {
req := &dbDataGrpc.DescribeIndexStatsRequest{
req := &db_data_grpc.DescribeIndexStatsRequest{
Filter: metadataFilter,
}
res, err := (*idx.grpcClient).DescribeIndexStats(idx.akCtx(ctx), req)
Expand Down Expand Up @@ -1012,18 +1012,18 @@ func (idx *IndexConnection) StartImport(ctx context.Context, uri string, integra
return nil, fmt.Errorf("must specify a uri to start an import")
}

var errorModeStruct *dbDataRest.ImportErrorMode
onErrorMode := pointerOrNil(dbDataRest.ImportErrorModeOnError(*errorMode))
var errorModeStruct *db_data_rest.ImportErrorMode
onErrorMode := pointerOrNil(db_data_rest.ImportErrorModeOnError(*errorMode))

if onErrorMode != nil {
errorModeStruct = &dbDataRest.ImportErrorMode{
errorModeStruct = &db_data_rest.ImportErrorMode{
OnError: onErrorMode,
}
}

intId := pointerOrNil(*integrationId)

req := dbDataRest.StartImportRequest{
req := db_data_rest.StartImportRequest{
Uri: &uri,
IntegrationId: intId,
ErrorMode: errorModeStruct,
Expand Down Expand Up @@ -1067,7 +1067,7 @@ type ListImportsResponse struct {
}

func (idx *IndexConnection) ListImports(ctx context.Context, req *ListImportsRequest) (*ListImportsResponse, error) {
params := dbDataRest.ListBulkImportsParams{
params := db_data_rest.ListBulkImportsParams{
Limit: req.Limit,
PaginationToken: req.PaginationToken,
}
Expand Down Expand Up @@ -1100,16 +1100,16 @@ func (idx *IndexConnection) CancelImport(ctx context.Context, id string) error {
}

// decodeListImportsResponse deserializes a JSON response body into the
// generated db_data_rest.ListImportsResponse and converts it to the public
// ListImportsResponse type. The caller remains responsible for closing body.
func decodeListImportsResponse(body io.ReadCloser) (*ListImportsResponse, error) {
	var listImportsResponse *db_data_rest.ListImportsResponse
	if err := json.NewDecoder(body).Decode(&listImportsResponse); err != nil {
		return nil, err
	}

	return toListImportsResponse(listImportsResponse), nil
}

func decodeImportModel(body io.ReadCloser) (*dbDataRest.ImportModel, error) {
var importModel dbDataRest.ImportModel
func decodeImportModel(body io.ReadCloser) (*db_data_rest.ImportModel, error) {
var importModel db_data_rest.ImportModel
if err := json.NewDecoder(body).Decode(&importModel); err != nil {
return nil, err
}
Expand All @@ -1118,15 +1118,15 @@ func decodeImportModel(body io.ReadCloser) (*dbDataRest.ImportModel, error) {
}

// decodeStartImportResponse deserializes a JSON response body into the
// generated db_data_rest.StartImportResponse and converts it to the public
// StartImportResponse type. The caller remains responsible for closing body.
func decodeStartImportResponse(body io.ReadCloser) (*StartImportResponse, error) {
	var importResponse *db_data_rest.StartImportResponse
	if err := json.NewDecoder(body).Decode(&importResponse); err != nil {
		return nil, err
	}

	return toImportResponse(importResponse), nil
}

func (idx *IndexConnection) query(ctx context.Context, req *dbDataGrpc.QueryRequest) (*QueryVectorsResponse, error) {
func (idx *IndexConnection) query(ctx context.Context, req *db_data_grpc.QueryRequest) (*QueryVectorsResponse, error) {
res, err := (*idx.grpcClient).Query(idx.akCtx(ctx), req)
if err != nil {
return nil, err
Expand All @@ -1144,7 +1144,7 @@ func (idx *IndexConnection) query(ctx context.Context, req *dbDataGrpc.QueryRequ
}, nil
}

func (idx *IndexConnection) delete(ctx context.Context, req *dbDataGrpc.DeleteRequest) error {
func (idx *IndexConnection) delete(ctx context.Context, req *db_data_grpc.DeleteRequest) error {
_, err := (*idx.grpcClient).Delete(idx.akCtx(ctx), req)
return err
}
Expand All @@ -1159,7 +1159,7 @@ func (idx *IndexConnection) akCtx(ctx context.Context) context.Context {
return metadata.AppendToOutgoingContext(ctx, newMetadata...)
}

func toVector(vector *dbDataGrpc.Vector) *Vector {
func toVector(vector *db_data_grpc.Vector) *Vector {
if vector == nil {
return nil
}
Expand All @@ -1171,11 +1171,11 @@ func toVector(vector *dbDataGrpc.Vector) *Vector {
}
}

func toScoredVector(sv *dbDataGrpc.ScoredVector) *ScoredVector {
func toScoredVector(sv *db_data_grpc.ScoredVector) *ScoredVector {
if sv == nil {
return nil
}
v := toVector(&dbDataGrpc.Vector{
v := toVector(&db_data_grpc.Vector{
Id: sv.Id,
Values: sv.Values,
SparseValues: sv.SparseValues,
Expand All @@ -1187,7 +1187,7 @@ func toScoredVector(sv *dbDataGrpc.ScoredVector) *ScoredVector {
}
}

func toSparseValues(sv *dbDataGrpc.SparseValues) *SparseValues {
func toSparseValues(sv *db_data_grpc.SparseValues) *SparseValues {
if sv == nil {
return nil
}
Expand All @@ -1197,7 +1197,7 @@ func toSparseValues(sv *dbDataGrpc.SparseValues) *SparseValues {
}
}

func toUsage(u *dbDataGrpc.Usage) *Usage {
func toUsage(u *db_data_grpc.Usage) *Usage {
if u == nil {
return nil
}
Expand All @@ -1206,21 +1206,21 @@ func toUsage(u *dbDataGrpc.Usage) *Usage {
}
}

func toPaginationTokenGrpc(p *dbDataGrpc.Pagination) *string {
func toPaginationTokenGrpc(p *db_data_grpc.Pagination) *string {
if p == nil {
return nil
}
return &p.Next
}

func toPaginationTokenRest(p *dbDataRest.Pagination) *string {
func toPaginationTokenRest(p *db_data_rest.Pagination) *string {
if p == nil {
return nil
}
return p.Next
}

func toImport(importModel *dbDataRest.ImportModel) *Import {
func toImport(importModel *db_data_rest.ImportModel) *Import {
if importModel == nil {
return nil
}
Expand All @@ -1235,7 +1235,7 @@ func toImport(importModel *dbDataRest.ImportModel) *Import {
}
}

func toImportResponse(importResponse *dbDataRest.StartImportResponse) *StartImportResponse {
func toImportResponse(importResponse *db_data_rest.StartImportResponse) *StartImportResponse {
if importResponse == nil {
return nil
}
Expand All @@ -1245,7 +1245,7 @@ func toImportResponse(importResponse *dbDataRest.StartImportResponse) *StartImpo
}
}

func toListImportsResponse(listImportsResponse *dbDataRest.ListImportsResponse) *ListImportsResponse {
func toListImportsResponse(listImportsResponse *db_data_rest.ListImportsResponse) *ListImportsResponse {
if listImportsResponse == nil {
return nil
}
Expand All @@ -1261,23 +1261,23 @@ func toListImportsResponse(listImportsResponse *dbDataRest.ListImportsResponse)
}
}

func vecToGrpc(v *Vector) *dbDataGrpc.Vector {
func vecToGrpc(v *Vector) *db_data_grpc.Vector {
if v == nil {
return nil
}
return &dbDataGrpc.Vector{
return &db_data_grpc.Vector{
Id: v.Id,
Values: v.Values,
Metadata: v.Metadata,
SparseValues: sparseValToGrpc(v.SparseValues),
}
}

func sparseValToGrpc(sv *SparseValues) *dbDataGrpc.SparseValues {
func sparseValToGrpc(sv *SparseValues) *db_data_grpc.SparseValues {
if sv == nil {
return nil
}
return &dbDataGrpc.SparseValues{
return &db_data_grpc.SparseValues{
Indices: sv.Indices,
Values: sv.Values,
}
Expand Down
Loading

0 comments on commit fc1911b

Please sign in to comment.