diff --git a/dialer_test.go b/dialer_test.go index 07f18bdd..9da624d8 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -1167,3 +1167,76 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) { ) } + +func TestDialerChecksSubjectAlternativeNameAndSucceeds(t *testing.T) { + + // Create an instance with custom SAN 'db.example.com' + inst := mock.NewFakeCSQLInstanceWithSan( + "my-project", "my-region", "my-instance", []string{"db.example.com"}, + mock.WithDNS("db.example.com"), + mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"), + ) + + wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com") + d := setupDialer(t, setupConfig{ + testInstance: inst, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + }, + dialerOptions: []Option{ + WithTokenSource(mock.EmptyTokenSource{}), + WithResolver(&fakeResolver{ + entries: map[string]instance.ConnName{ + "db.example.com": wantName, + }, + }), + }, + }) + + // Dial db.example.com + testSuccessfulDial( + context.Background(), t, d, + "db.example.com", + ) +} + +func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) { + + // Create an instance with custom SAN 'db.example.com' + inst := mock.NewFakeCSQLInstanceWithSan( + "my-project", "my-region", "my-instance", []string{"db.example.com"}, + mock.WithDNS("db.example.com"), + mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"), + ) + + // Resolve the dns name 'bad.example.com' to the the instance. + wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "bad.example.com") + + d := setupDialer(t, setupConfig{ + testInstance: inst, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + }, + dialerOptions: []Option{ + WithTokenSource(mock.EmptyTokenSource{}), + WithResolver(&fakeResolver{ + entries: map[string]instance.ConnName{ + "bad.example.com": wantName, + }, + }), + }, + }) + + // Dial 'bad.example.com'. This will error as 'failed to verify certificate' + _, err := d.Dial( + context.Background(), "bad.example.com", + ) + if err == nil { + t.Fatal("want dial error, got no error") + } + if !strings.Contains(fmt.Sprint(err), "tls: failed to verify certificate") { + t.Fatal("want error containing `tls: failed to verify certificate`. Got: ", err) + } +} diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index e70d5510..6a34570e 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -241,11 +241,22 @@ func (c ConnectionInfo) TLSConfig() *tls.Config { for _, caCert := range c.ServerCACert { pool.AddCert(caCert) } + + // For CAS instances, we can rely on the DNS name to verify the server identity. if c.ServerCAMode != "" && c.ServerCAMode != "GOOGLE_MANAGED_INTERNAL_CA" { // By default, use Standard TLS hostname verification name to // verify the server identity. + + // If the connector was configured with a domain name, use that domain name + // to validate the certificate. Otherwise, use the DNS name from the + // instance ConnectionInfo API response. + serverName := c.ConnectionName.DomainName() + if serverName == "" { + serverName = c.DNSName + } + return &tls.Config{ - ServerName: c.DNSName, + ServerName: serverName, Certificates: []tls.Certificate{c.ClientCertificate}, RootCAs: pool, MinVersion: tls.VersionTLS13, diff --git a/internal/mock/cloudsql.go b/internal/mock/cloudsql.go index 27fc80a3..f871cf28 100644 --- a/internal/mock/cloudsql.go +++ b/internal/mock/cloudsql.go @@ -178,8 +178,14 @@ func WithServerCAMode(serverCAMode string) FakeCSQLInstanceOption { // NewFakeCSQLInstance returns a CloudSQLInst object for configuring mocks. func NewFakeCSQLInstance(project, region, name string, opts ...FakeCSQLInstanceOption) FakeCSQLInstance { + return NewFakeCSQLInstanceWithSan(project, region, name, nil, opts...) +} + +// NewFakeCSQLInstanceWithSan returns a CloudSQLInst object for configuring +// mocks, including SubjectAlternativeNames in the server certificate. +func NewFakeCSQLInstanceWithSan(project, region, name string, sanDNSNames []string, opts ...FakeCSQLInstanceOption) FakeCSQLInstance { // TODO: consider options for this? - key, cert, err := generateCerts(project, name) + key, cert, err := generateCerts(project, name, sanDNSNames) if err != nil { panic(err) } @@ -274,7 +280,7 @@ func GenerateCertWithCommonName(i FakeCSQLInstance, cn string) []byte { // generateCerts generates a private key, an X.509 certificate, and a TLS // certificate for a particular fake Cloud SQL database instance. -func generateCerts(project, name string) (*rsa.PrivateKey, *x509.Certificate, error) { +func generateCerts(project, name string, dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, nil, err @@ -291,6 +297,7 @@ func generateCerts(project, name string) (*rsa.PrivateKey, *x509.Certificate, er ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, + DNSNames: dnsNames, } return key, cert, nil