From 11b9b192328fadbbd2c0da8b3163b7a3f4cb8d39 Mon Sep 17 00:00:00 2001 From: Carlos Yakimov Date: Sun, 14 Apr 2019 16:42:00 -0400 Subject: [PATCH] Implement request authentication with auth0 --- authentication/auth.go | 110 ++++++++++++++++++++++++ authentication/identity.go | 97 --------------------- authentication/{ => providers}/auth0.go | 33 ++++--- authentication/providers/provider.go | 27 ++++++ authentication/token.go | 18 +++- config.example.yaml | 16 ++-- config.go | 17 ++-- go.mod | 2 +- go.sum | 4 +- main.go | 43 +++++---- 10 files changed, 220 insertions(+), 147 deletions(-) create mode 100644 authentication/auth.go delete mode 100644 authentication/identity.go rename authentication/{ => providers}/auth0.go (61%) create mode 100644 authentication/providers/provider.go diff --git a/authentication/auth.go b/authentication/auth.go new file mode 100644 index 0000000..21296a9 --- /dev/null +++ b/authentication/auth.go @@ -0,0 +1,110 @@ +package authentication + +import ( + "encoding/base64" + "errors" + "github.com/cyakimov/helios/authentication/providers" + log "github.com/sirupsen/logrus" + "net/http" + "time" +) + +const CookieName = "Helios_Authorization" +const HeaderName = "Helios-Jwt-Assertion" + +var ErrUnauthorized = errors.New("unauthorized request") + +type JWTOpts struct { + Secret string + Expiration time.Duration +} + +type Helios struct { + provider providers.OAuth2 + jwtOpts JWTOpts +} + +func NewHeliosAuthentication(provider providers.OAuth2, jwtSecret string, jwtExpiration time.Duration) Helios { + return Helios{ + provider: provider, + jwtOpts: JWTOpts{ + Secret: jwtSecret, + Expiration: jwtExpiration, + }, + } +} + +func (helios Helios) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := authenticate(helios.jwtOpts.Secret, r); err != nil { + + // dynamically build callback URL based on current domain + callback := "https://" + r.Host + "/.oauth2/callback" + + url := helios.provider.GetLoginURL(callback, r.RequestURI) + + log.Debugf("Redirecting to %s", url) + + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + + log.Println(r.RequestURI) + // Call the next handler, which can be another middleware in the chain, or the final handler. + next.ServeHTTP(w, r) + }) +} + +func (helios Helios) CallbackHandler(w http.ResponseWriter, r *http.Request) { + // decode and decrypt state to recover original request url + encodedState := r.URL.Query().Get("state") + + state, err := base64.StdEncoding.DecodeString(encodedState) + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + profile, err := helios.provider.GetUserProfile(r) + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + log.Debugf("Authorized. Redirecting to %s", string(state)) + + exp := time.Now().Add(helios.jwtOpts.Expiration) + jwt, err := IssueJWTWithSecret(helios.jwtOpts.Secret, profile.Email, exp) + http.SetCookie(w, &http.Cookie{ + Name: CookieName, + Value: jwt, + Expires: exp, + Path: "/", + Secure: true, + }) + + http.Redirect(w, r, string(state), http.StatusFound) + return +} + +func authenticate(jwtSecret string, r *http.Request) error { + // look for Token in both Cookies and Headers + cookie, err := r.Cookie(CookieName) + token := r.Header.Get(HeaderName) + + if err == http.ErrNoCookie && token == "" { + return ErrUnauthorized + } + + if token == "" { + token = cookie.Value + } + + if !ValidateJWTWithSecret(jwtSecret, token) { + return ErrUnauthorized + } + + return nil +} diff --git a/authentication/identity.go b/authentication/identity.go deleted file mode 100644 index 8ae17f9..0000000 --- a/authentication/identity.go +++ /dev/null @@ -1,97 +0,0 @@ -package authentication - -import ( - "encoding/base64" - "errors" - log "github.com/sirupsen/logrus" - "net/http" - "time" -) - -type OAuth2Config struct { - ClientID string - ClientSecret string - CallbackURL string - AuthURL string - TokenURL string - Domain string -} - -type OIDCProfile struct { - Email string -} - -type OAuth2Provider interface { - GetUserProfile(r *http.Request) (OIDCProfile, error) - GetLoginURL(state string) string -} - -const CookieName = "Helios_Authorization" -const HeaderName = "Helios-Jwt-Assertion" - -var ErrUnauthorized = errors.New("unauthorized request") -var ErrCodeExchange = errors.New("error on code exchange") -var ErrProfile = errors.New("error getting user profile") -var ErrNoEmail = errors.New("no email found in user profile") - -func authenticate(r *http.Request) error { - // look for Token in both Cookies and Headers - _, err := r.Cookie(CookieName) - htoken := r.Header.Get(HeaderName) - - if err == http.ErrNoCookie && htoken == "" { - return ErrUnauthorized - } - - return nil -} - -func Middleware(provider OAuth2Provider, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := authenticate(r); err != nil { - url := provider.GetLoginURL(r.RequestURI) - log.Debugf("Redirecting to %s", url) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) - return - } - - log.Println(r.RequestURI) - // Call the next handler, which can be another middleware in the chain, or the final handler. - next.ServeHTTP(w, r) - }) -} - -func CallbackHandler(provider OAuth2Provider, jwtSecret string, jwtDuration time.Duration) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - state := r.URL.Query().Get("state") - - dstate, err := base64.StdEncoding.DecodeString(state) - if err != nil { - log.Error(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - profile, err := provider.GetUserProfile(r) - if err != nil { - log.Error(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - log.Debugf("Authorized. Redirecting to %s", string(dstate)) - - exp := time.Now().Add(jwtDuration) - jwt, err := IssueJWT(jwtSecret, profile.Email, exp) - http.SetCookie(w, &http.Cookie{ - Name: CookieName, - Value: jwt, - Expires: exp, - Path: "/", - Secure: true, - }) - - http.Redirect(w, r, string(dstate), http.StatusFound) - return - }) -} diff --git a/authentication/auth0.go b/authentication/providers/auth0.go similarity index 61% rename from authentication/auth0.go rename to authentication/providers/auth0.go index 942e4d6..f3232af 100644 --- a/authentication/auth0.go +++ b/authentication/providers/auth0.go @@ -1,4 +1,4 @@ -package authentication +package providers import ( "context" @@ -11,19 +11,19 @@ import ( ) type Auth0Provider struct { - OAuth2Provider - oauth2 oauth2.Config - domain string + OAuth2 + oauth2 oauth2.Config + profileURL string } -func NewAuth0Provider(config OAuth2Config) OAuth2Provider { +func NewAuth0Provider(config OAuth2Config) OAuth2 { return Auth0Provider{ - domain: config.Domain, + profileURL: config.ProfileURL, oauth2: oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, - RedirectURL: config.CallbackURL, Scopes: []string{"openid", "email_verified", "email"}, + RedirectURL: "", // RedirectURL can vary per route host Endpoint: oauth2.Endpoint{ AuthURL: config.AuthURL, TokenURL: config.TokenURL, @@ -36,14 +36,20 @@ func (provider Auth0Provider) GetUserProfile(r *http.Request) (OIDCProfile, erro var profile OIDCProfile code := r.URL.Query().Get("code") - token, err := provider.oauth2.Exchange(context.TODO(), code) + // Auth0 requires callback URL + url := "https://" + r.Host + "/" + r.URL.Path + callback := oauth2.SetAuthURLParam("redirect_uri", url) + + // get access token + token, err := provider.oauth2.Exchange(context.TODO(), code, callback) if err != nil { + log.Error(err) return profile, ErrCodeExchange } - // Get user profile + // get user profile client := provider.oauth2.Client(context.TODO(), token) - resp, err := client.Get(provider.domain + "/userinfo") + resp, err := client.Get(provider.profileURL) if err != nil { return profile, ErrProfile } @@ -67,7 +73,10 @@ func (provider Auth0Provider) GetUserProfile(r *http.Request) (OIDCProfile, erro return profile, nil } -func (provider Auth0Provider) GetLoginURL(state string) string { +func (provider Auth0Provider) GetLoginURL(callbackURL string, state string) string { s := base64.StdEncoding.EncodeToString([]byte(state)) - return provider.oauth2.AuthCodeURL(s) + + callback := oauth2.SetAuthURLParam("redirect_uri", callbackURL) + + return provider.oauth2.AuthCodeURL(s, callback) } diff --git a/authentication/providers/provider.go b/authentication/providers/provider.go new file mode 100644 index 0000000..18d1f3a --- /dev/null +++ b/authentication/providers/provider.go @@ -0,0 +1,27 @@ +package providers + +import ( + "errors" + "net/http" +) + +type OAuth2Config struct { + ClientID string + ClientSecret string + AuthURL string + TokenURL string + ProfileURL string +} + +type OIDCProfile struct { + Email string +} + +type OAuth2 interface { + GetUserProfile(r *http.Request) (OIDCProfile, error) + GetLoginURL(callbackURL, state string) string +} + +var ErrCodeExchange = errors.New("error on code exchange") +var ErrProfile = errors.New("error getting user profile") +var ErrNoEmail = errors.New("no email found in user profile") diff --git a/authentication/token.go b/authentication/token.go index ba87fc0..7e991bd 100644 --- a/authentication/token.go +++ b/authentication/token.go @@ -1,11 +1,13 @@ package authentication import ( + "fmt" "github.com/dgrijalva/jwt-go" "time" ) -func IssueJWT(secret string, email string, expires time.Time) (string, error) { +// IssueJWTWithSecret issues and sign a JWT with a secret +func IssueJWTWithSecret(secret, email string, expires time.Time) (string, error) { key := []byte(secret) // Create the Claims @@ -19,3 +21,17 @@ func IssueJWT(secret string, email string, expires time.Time) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString(key) } + +// ValidateJWTWithSecret checks JWT signing algorithm as well the signature +func ValidateJWTWithSecret(secret, tokenString string) bool { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Validate the alg + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(secret), nil + }) + + return err == nil && token != nil && token.Valid +} diff --git a/config.example.yaml b/config.example.yaml index a5eb849..8871e4c 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -19,20 +19,24 @@ routes: paths: - path: / upstream: httpbin - - path: /json - upstream: httpbin - auth_enabled: false + + - host: 127.0.0.1 + http: + paths: + - path: / + upstream: httpbin + auth_enabled: false identity: provider: auth0 client_id: long-hash-here client_secret: long-hash-here oauth2: - domain: https://yourtenant.auth0.com - callback: https://localhost/callback auth_url: https://yourtenant.auth0.com/authorize token_url: https://yourtenant.auth0.com/oauth/token + profile_url: https://yourtenant.auth0.com/userinfo + state_secret: long-hash-here jwt: - shared_secret: replace-this-with-a-long-hash + secret: replace-this-with-a-long-hash expires: 10h diff --git a/config.go b/config.go index dddcb17..311e2f8 100644 --- a/config.go +++ b/config.go @@ -47,16 +47,15 @@ type Identity struct { ClientID string `yaml:"client_id"` ClientSecret string `yaml:"client_secret"` OAuth2 struct { - CallbackURL string `yaml:"callback_url"` - AuthURL string `yaml:"auth_url"` - TokenURL string `yaml:"token_url"` - Domain string `yaml:"domain"` + AuthURL string `yaml:"auth_url"` + TokenURL string `yaml:"token_url"` + ProfileURL string `yaml:"profile_url"` } } type JWT struct { - SharedSecret string - Expires time.Duration + Secret string + Expires time.Duration } func (c *Upstream) UnmarshalYAML(unmarshal func(v interface{}) error) error { @@ -84,8 +83,8 @@ func (c *Upstream) UnmarshalYAML(unmarshal func(v interface{}) error) error { func (c *JWT) UnmarshalYAML(unmarshal func(v interface{}) error) error { buf := struct { - SharedSecret string `yaml:"shared_secret"` - Expires string `yaml:"expires"` + Secret string `yaml:"secret"` + Expires string `yaml:"expires"` }{} if err := unmarshal(&buf); err != nil { @@ -98,7 +97,7 @@ func (c *JWT) UnmarshalYAML(unmarshal func(v interface{}) error) error { } c.Expires = expires - c.SharedSecret = buf.SharedSecret + c.Secret = buf.Secret return nil } diff --git a/go.mod b/go.mod index 1714fdf..8ded5d1 100644 --- a/go.mod +++ b/go.mod @@ -8,5 +8,5 @@ require ( github.com/sirupsen/logrus v1.4.1 golang.org/x/oauth2 v0.0.0-20190402181905-9f3314589c9a golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a // indirect - gopkg.in/yaml.v2 v2.2.2 + gopkg.in/yaml.v3 v3.0.0-20190409140830-cdc409dda467 ) diff --git a/go.sum b/go.sum index f540548..28adc2d 100644 --- a/go.sum +++ b/go.sum @@ -31,5 +31,5 @@ google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO50 google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20190409140830-cdc409dda467 h1:w3VhdSYz2sIVz54Ta/eDCCfCQ4fQkDgRxMACggArIUw= +gopkg.in/yaml.v3 v3.0.0-20190409140830-cdc409dda467/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 95423e1..ac10a33 100644 --- a/main.go +++ b/main.go @@ -6,9 +6,10 @@ import ( "flag" "fmt" "github.com/cyakimov/helios/authentication" + "github.com/cyakimov/helios/authentication/providers" "github.com/gorilla/mux" log "github.com/sirupsen/logrus" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "io/ioutil" "net/http" "net/url" @@ -23,34 +24,41 @@ var ( configPath string config *Config tlsConfig *tls.Config + debugMode bool ) func init() { + // Enable TLS 1.3 _ = os.Setenv("GODEBUG", os.Getenv("GODEBUG")+",tls13=1") flag.StringVar(&configPath, "config", "default.yaml", "Configuration file path") + flag.BoolVar(&debugMode, "verbose", false, "DEBUG level logging") flag.Parse() - log.SetLevel(log.DebugLevel) + if debugMode { + log.SetLevel(log.DebugLevel) + } else { + log.SetLevel(log.InfoLevel) + } } -func setupRoutes() *mux.Router { +func router() *mux.Router { router := mux.NewRouter() - upstreams := make(map[string]*http.Handler, len(config.Upstreams)) + upstreams := make(map[string]http.Handler, len(config.Upstreams)) - oauth2conf := authentication.OAuth2Config{ + oauth2conf := providers.OAuth2Config{ ClientID: config.Identity.ClientID, ClientSecret: config.Identity.ClientSecret, - CallbackURL: config.Identity.OAuth2.CallbackURL, AuthURL: config.Identity.OAuth2.AuthURL, TokenURL: config.Identity.OAuth2.TokenURL, - Domain: config.Identity.OAuth2.Domain, + ProfileURL: config.Identity.OAuth2.ProfileURL, } - auth0 := authentication.NewAuth0Provider(oauth2conf) + auth0 := providers.NewAuth0Provider(oauth2conf) + + auth := authentication.NewHeliosAuthentication(auth0, config.JWT.Secret, config.JWT.Expires) - router.PathPrefix("/.identity/callback").Handler( - authentication.CallbackHandler(auth0, config.JWT.SharedSecret, config.JWT.Expires)) + router.PathPrefix("/.oauth2/callback").HandlerFunc(auth.CallbackHandler) for _, up := range config.Upstreams { upstreamURL, err := url.Parse(up.URL) @@ -64,7 +72,7 @@ func setupRoutes() *mux.Router { Timeout: config.Server.Timeout, } proxy := NewSingleHostReverseProxy(upstreamURL, conf) - upstreams[up.Name] = &proxy + upstreams[up.Name] = proxy } for _, route := range config.Routes { @@ -79,9 +87,9 @@ func setupRoutes() *mux.Router { } if path.AuthEnabled { - h.PathPrefix(path.Path).Handler(authentication.Middleware(auth0, *upstream)) + h.PathPrefix(path.Path).Handler(auth.Middleware(upstream)) } else { - h.PathPrefix(path.Path).Handler(*upstream) + h.PathPrefix(path.Path).Handler(upstream) } } @@ -91,13 +99,12 @@ func setupRoutes() *mux.Router { } func main() { - // Set sane defaults cb, err := ioutil.ReadFile(configPath) if err != nil { log.Fatalf("Error loading configuration: %v", err) } - if err = yaml.UnmarshalStrict(cb, &config); err != nil { + if err = yaml.Unmarshal(cb, &config); err != nil { log.Fatalf("Error parsing configuration: %v", err) } @@ -108,8 +115,6 @@ func main() { MaxVersion: tls.VersionTLS13, } - router := setupRoutes() - address := fmt.Sprintf("%s:%d", config.Server.ListenIP, config.Server.ListenPort) srv := &http.Server{ Addr: address, @@ -117,8 +122,8 @@ func main() { ReadTimeout: config.Server.Timeout, IdleTimeout: config.Server.IdleTimeout, TLSConfig: tlsConfig, - MaxHeaderBytes: 1 << 20, - Handler: router, + MaxHeaderBytes: 1 << 20, // 1mb + Handler: router(), } // Run our server in a goroutine so that it doesn't block.