Skip to content

Commit

Permalink
Merge pull request #345 from AzureAD/release-0.6.1
Browse files Browse the repository at this point in the history
Release 0.6.1
  • Loading branch information
rayluo authored Aug 18, 2022
2 parents a04b770 + efebb96 commit 8d382bd
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 7 deletions.
33 changes: 27 additions & 6 deletions apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,23 @@ type Credential struct {
// code requires that client.go, requests.go and confidential.go share a credential type without
// having import recursion. That requires the type used between is in a shared package. Therefore
// we have this.
func (c Credential) toInternal() *accesstokens.Credential {
return &accesstokens.Credential{Secret: c.secret, Cert: c.cert, Key: c.key, AssertionCallback: c.assertionCallback, X5c: c.x5c}
func (c Credential) toInternal() (*accesstokens.Credential, error) {
if c.secret != "" {
return &accesstokens.Credential{Secret: c.secret}, nil
}
if c.cert != nil {
if c.key == nil {
return nil, errors.New("missing private key for certificate")
}
return &accesstokens.Credential{Cert: c.cert, Key: c.key, X5c: c.x5c}, nil
}
if c.key != nil {
return nil, errors.New("missing certificate for private key")
}
if c.assertionCallback != nil {
return &accesstokens.Credential{AssertionCallback: c.assertionCallback}, nil
}
return nil, errors.New("invalid credential")
}

// NewCredFromSecret creates a Credential from a secret.
Expand Down Expand Up @@ -191,6 +206,10 @@ func NewCredFromCertChain(certs []*x509.Certificate, key crypto.PrivateKey) (Cre
return cred, errors.New("key must be an RSA key")
}
for _, cert := range certs {
if cert == nil {
// not returning an error here because certs may still contain a sufficient cert/key pair
continue
}
certKey, ok := cert.PublicKey.(*rsa.PublicKey)
if ok && k.E == certKey.E && k.N.Cmp(certKey.N) == 0 {
// We know this is the signing cert because its public key matches the given private key.
Expand Down Expand Up @@ -312,6 +331,11 @@ func WithAzureRegion(val string) Option {
// will store credentials for (a Client is per user). clientID is the Azure clientID and cred is
// the type of credential to use.
func New(clientID string, cred Credential, options ...Option) (Client, error) {
internalCred, err := cred.toInternal()
if err != nil {
return Client{}, err
}

opts := Options{
Authority: base.AuthorityPublicCloud,
HTTPClient: shared.DefaultClient,
Expand All @@ -329,10 +353,7 @@ func New(clientID string, cred Credential, options ...Option) (Client, error) {
return Client{}, err
}

return Client{
base: base,
cred: cred.toInternal(),
}, nil
return Client{base: base, cred: internalCred}, nil
}

// UserID is the unique user identifier this client if for.
Expand Down
56 changes: 56 additions & 0 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package confidential

import (
"context"
"crypto"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
Expand Down Expand Up @@ -239,6 +240,31 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
}
}

func TestInvalidCredential(t *testing.T) {
data, err := os.ReadFile("../testdata/test-cert.pem")
if err != nil {
t.Fatal(err)
}
certs, key, err := CertFromPEM(data, "")
if err != nil {
t.Fatal(err)
}
for _, cred := range []Credential{
{},
NewCredFromAssertionCallback(nil),
NewCredFromCert(nil, nil),
NewCredFromCert(certs[0], nil),
NewCredFromCert(nil, key),
} {
t.Run("", func(t *testing.T) {
_, err := New("client-id", cred)
if err == nil {
t.Fatal("expected an error")
}
})
}
}

func TestNewCredFromCertChain(t *testing.T) {
for _, file := range []struct {
path string
Expand Down Expand Up @@ -355,3 +381,33 @@ func TestNewCredFromCertChain(t *testing.T) {
}
}
}

func TestNewCredFromCertChainError(t *testing.T) {
data, err := os.ReadFile("../testdata/test-cert.pem")
if err != nil {
t.Fatal(err)
}
certs, key, err := CertFromPEM(data, "")
if err != nil {
t.Fatal(err)
}
for _, test := range []struct {
certs []*x509.Certificate
key crypto.PrivateKey
}{
{nil, nil},
{certs, nil},
{nil, key},
{[]*x509.Certificate{}, nil},
{[]*x509.Certificate{}, key},
{[]*x509.Certificate{nil}, nil},
{[]*x509.Certificate{nil}, key},
} {
t.Run("", func(t *testing.T) {
_, err := NewCredFromCertChain(test.certs, test.key)
if err == nil {
t.Fatal("expected an error")
}
})
}
}
2 changes: 1 addition & 1 deletion apps/internal/version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
package version

// Version is the version of this client package that is communicated to the server.
const Version = "0.6.0"
const Version = "0.6.1"

0 comments on commit 8d382bd

Please sign in to comment.