Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix normalizeHost, allow for non-secure connections #74

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions pinecone/index_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/pinecone-io/go-pinecone/internal/useragent"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)

Expand All @@ -37,19 +38,19 @@ type newIndexParameters struct {
}

func newIndexConnection(in newIndexParameters, dialOpts ...grpc.DialOption) (*IndexConnection, error) {
target := normalizeHost(in.host)
target, isSecure := normalizeHost(in.host)

// configure default gRPC DialOptions
grpcOptions := []grpc.DialOption{
grpc.WithAuthority(target),
grpc.WithUserAgent(useragent.BuildUserAgentGRPC(in.sourceTag)),
}

// if the target includes an http:// address, don't include TLS
// otherwise we need to add transport credentials
if !strings.HasPrefix(target, "http://") {
if isSecure {
config := &tls.Config{}
grpcOptions = append(grpcOptions, grpc.WithTransportCredentials(credentials.NewTLS(config)))
} else {
grpcOptions = append(grpcOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
}

// if we have user-provided dialOpts, append them to the defaults here
Expand Down Expand Up @@ -1104,24 +1105,26 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues {
}
}

func normalizeHost(host string) string {
func normalizeHost(host string) (string, bool) {
// default to secure unless http is specified
isSecure := true

parsedHost, err := url.Parse(host)
if err != nil {
log.Default().Printf("Failed to parse host %s: %v", host, err)
return host
return host, isSecure
}

if parsedHost.Scheme == "http" {
isSecure = false
}

// if https:// or http:// without a port, strip the scheme
// the gRPC client is not expecting a scheme so we strip that out
if parsedHost.Scheme == "https" {
host = strings.TrimPrefix(host, "https://")
} else if parsedHost.Scheme == "http" && parsedHost.Port() == "" {
} else if parsedHost.Scheme == "http" {
host = strings.TrimPrefix(host, "http://")
}

// if a port was provided leave it, otherwise we append :443
if parsedHost.Port() == "" {
host = host + ":443"
}

return host
return host, isSecure
}
32 changes: 14 additions & 18 deletions pinecone/index_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1000,33 +1000,29 @@ func TestToUsageUnit(t *testing.T) {

func TestNormalizeHostUnit(t *testing.T) {
tests := []struct {
name string
host string
expectedHost string
name string
host string
expectedHost string
expectedIsSecure bool
}{
{
name: "https:// scheme should be removed",
host: "https://this-is-my-host.io",
expectedHost: "this-is-my-host.io:443",
}, {
name: "https:// scheme with a port should be removed",
host: "https://this-is-my-host.io:33445",
expectedHost: "this-is-my-host.io:33445",
}, {
name: "http:// scheme without a port should be removed",
host: "http://this-is-my-host.io",
expectedHost: "this-is-my-host.io:443",
name: "https:// scheme should be removed",
host: "https://this-is-my-host.io",
expectedHost: "this-is-my-host.io",
expectedIsSecure: true,
}, {
name: "http:// scheme and port should be maintained",
host: "http://this-is-my-host.io:8080",
expectedHost: "http://this-is-my-host.io:8080",
name: "https:// scheme should be removed",
host: "https://this-is-my-host.io:33445",
expectedHost: "this-is-my-host.io:33445",
expectedIsSecure: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := normalizeHost(tt.host)
result, isSecure := normalizeHost(tt.host)
assert.Equal(t, tt.expectedHost, result, "Expected result to be '%s', but got '%s'", tt.expectedHost, result)
assert.Equal(t, tt.expectedIsSecure, isSecure, "Expected isSecure to be '%t', but got '%t'", tt.expectedIsSecure, isSecure)
})
}
}
Expand Down