From d6023ada5467f3d1be4b1d3b7897ae1ff1e9e838 Mon Sep 17 00:00:00 2001 From: Carlos Yakimov Date: Tue, 9 Apr 2019 23:01:41 -0400 Subject: [PATCH] Add request authentication logic --- .gitignore | 1 + authentication/auth0.go | 73 ++++++++++++++++++++++++++++ authentication/identity.go | 97 ++++++++++++++++++++++++++++++++++++++ authentication/token.go | 21 +++++++++ config.example.yaml | 17 +++++++ config.go | 55 ++++++++++++++++++--- go.mod | 2 + go.sum | 15 ++++++ main.go | 26 ++++++++-- 9 files changed, 297 insertions(+), 10 deletions(-) create mode 100644 authentication/auth0.go create mode 100644 authentication/identity.go create mode 100644 authentication/token.go diff --git a/.gitignore b/.gitignore index e2f6b5a..919fda2 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ *.out *.pem +config.yaml diff --git a/authentication/auth0.go b/authentication/auth0.go new file mode 100644 index 0000000..942e4d6 --- /dev/null +++ b/authentication/auth0.go @@ -0,0 +1,73 @@ +package authentication + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + log "github.com/sirupsen/logrus" + "golang.org/x/oauth2" + "net/http" +) + +type Auth0Provider struct { + OAuth2Provider + oauth2 oauth2.Config + domain string +} + +func NewAuth0Provider(config OAuth2Config) OAuth2Provider { + return Auth0Provider{ + domain: config.Domain, + oauth2: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.CallbackURL, + Scopes: []string{"openid", "email_verified", "email"}, + Endpoint: oauth2.Endpoint{ + AuthURL: config.AuthURL, + TokenURL: config.TokenURL, + }, + }, + } +} + +func (provider Auth0Provider) GetUserProfile(r *http.Request) (OIDCProfile, error) { + var profile OIDCProfile + code := r.URL.Query().Get("code") + + token, err := provider.oauth2.Exchange(context.TODO(), code) + if err != nil { + return profile, ErrCodeExchange + } + + // Get user profile + client := provider.oauth2.Client(context.TODO(), token) + resp, err := client.Get(provider.domain + "/userinfo") + if err != nil { + return profile, ErrProfile + } + + defer func() { + if resp.Body != nil { + if err := resp.Body.Close(); err != nil { + log.Error(err) + } + } + }() + + if err = json.NewDecoder(resp.Body).Decode(&profile); err != nil { + return profile, errors.New("auth0: cannot decode JSON profile") + } + + if profile.Email == "" { + return profile, ErrNoEmail + } + + return profile, nil +} + +func (provider Auth0Provider) GetLoginURL(state string) string { + s := base64.StdEncoding.EncodeToString([]byte(state)) + return provider.oauth2.AuthCodeURL(s) +} diff --git a/authentication/identity.go b/authentication/identity.go new file mode 100644 index 0000000..8ae17f9 --- /dev/null +++ b/authentication/identity.go @@ -0,0 +1,97 @@ +package authentication + +import ( + "encoding/base64" + "errors" + log "github.com/sirupsen/logrus" + "net/http" + "time" +) + +type OAuth2Config struct { + ClientID string + ClientSecret string + CallbackURL string + AuthURL string + TokenURL string + Domain string +} + +type OIDCProfile struct { + Email string +} + +type OAuth2Provider interface { + GetUserProfile(r *http.Request) (OIDCProfile, error) + GetLoginURL(state string) string +} + +const CookieName = "Helios_Authorization" +const HeaderName = "Helios-Jwt-Assertion" + +var ErrUnauthorized = errors.New("unauthorized request") +var ErrCodeExchange = errors.New("error on code exchange") +var ErrProfile = errors.New("error getting user profile") +var ErrNoEmail = errors.New("no email found in user profile") + +func authenticate(r *http.Request) error { + // look for Token in both Cookies and Headers + _, err := r.Cookie(CookieName) + htoken := r.Header.Get(HeaderName) + + if err == http.ErrNoCookie && htoken == "" { + return ErrUnauthorized + } + + return nil +} + +func Middleware(provider OAuth2Provider, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := authenticate(r); err != nil { + url := provider.GetLoginURL(r.RequestURI) + log.Debugf("Redirecting to %s", url) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) + return + } + + log.Println(r.RequestURI) + // Call the next handler, which can be another middleware in the chain, or the final handler. + next.ServeHTTP(w, r) + }) +} + +func CallbackHandler(provider OAuth2Provider, jwtSecret string, jwtDuration time.Duration) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + state := r.URL.Query().Get("state") + + dstate, err := base64.StdEncoding.DecodeString(state) + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + profile, err := provider.GetUserProfile(r) + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + log.Debugf("Authorized. Redirecting to %s", string(dstate)) + + exp := time.Now().Add(jwtDuration) + jwt, err := IssueJWT(jwtSecret, profile.Email, exp) + http.SetCookie(w, &http.Cookie{ + Name: CookieName, + Value: jwt, + Expires: exp, + Path: "/", + Secure: true, + }) + + http.Redirect(w, r, string(dstate), http.StatusFound) + return + }) +} diff --git a/authentication/token.go b/authentication/token.go new file mode 100644 index 0000000..ba87fc0 --- /dev/null +++ b/authentication/token.go @@ -0,0 +1,21 @@ +package authentication + +import ( + "github.com/dgrijalva/jwt-go" + "time" +) + +func IssueJWT(secret string, email string, expires time.Time) (string, error) { + key := []byte(secret) + + // Create the Claims + claims := &jwt.StandardClaims{ + ExpiresAt: expires.Unix(), + Subject: email, + Issuer: "Helios", + IssuedAt: time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(key) +} diff --git a/config.example.yaml b/config.example.yaml index 658d880..a5eb849 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -19,3 +19,20 @@ routes: paths: - path: / upstream: httpbin + - path: /json + upstream: httpbin + auth_enabled: false + +identity: + provider: auth0 + client_id: long-hash-here + client_secret: long-hash-here + oauth2: + domain: https://yourtenant.auth0.com + callback: https://localhost/callback + auth_url: https://yourtenant.auth0.com/authorize + token_url: https://yourtenant.auth0.com/oauth/token + +jwt: + shared_secret: replace-this-with-a-long-hash + expires: 10h diff --git a/config.go b/config.go index 1b5dc7e..dddcb17 100644 --- a/config.go +++ b/config.go @@ -5,12 +5,14 @@ import ( ) type Config struct { - Server ServerConfig `yaml:"server"` - Upstreams []Upstream `yaml:"upstreams"` - Routes []Route `yaml:"routes"` + Server Server `yaml:"server"` + Upstreams []Upstream `yaml:"upstreams"` + Routes []Route `yaml:"routes"` + Identity Identity `yaml:"identity"` + JWT JWT `yaml:"jwt"` } -type ServerConfig struct { +type Server struct { ListenIP string `yaml:"listen_ip"` ListenPort int `yaml:"listen_port"` Timeout time.Duration `yaml:"timeout"` @@ -27,8 +29,9 @@ type Route struct { Host string HTTP struct { Paths []struct { - Path string - Upstream string + Path string + Upstream string + AuthEnabled bool `yaml:"auth_enabled"` } } } @@ -39,6 +42,23 @@ type Upstream struct { ConnectTimeout time.Duration } +type Identity struct { + Provider string `yaml:"provider"` + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + OAuth2 struct { + CallbackURL string `yaml:"callback_url"` + AuthURL string `yaml:"auth_url"` + TokenURL string `yaml:"token_url"` + Domain string `yaml:"domain"` + } +} + +type JWT struct { + SharedSecret string + Expires time.Duration +} + func (c *Upstream) UnmarshalYAML(unmarshal func(v interface{}) error) error { buf := struct { ConnectTimeout string `yaml:"connect_timeout"` @@ -62,7 +82,28 @@ func (c *Upstream) UnmarshalYAML(unmarshal func(v interface{}) error) error { return nil } -func (c *ServerConfig) UnmarshalYAML(unmarshal func(v interface{}) error) error { +func (c *JWT) UnmarshalYAML(unmarshal func(v interface{}) error) error { + buf := struct { + SharedSecret string `yaml:"shared_secret"` + Expires string `yaml:"expires"` + }{} + + if err := unmarshal(&buf); err != nil { + return err + } + + expires, err := time.ParseDuration(buf.Expires) + if err != nil { + return err + } + + c.Expires = expires + c.SharedSecret = buf.SharedSecret + + return nil +} + +func (c *Server) UnmarshalYAML(unmarshal func(v interface{}) error) error { var buf struct { ListenIP string `yaml:"listen_ip"` ListenPort int `yaml:"listen_port"` diff --git a/go.mod b/go.mod index 5709ba0..1714fdf 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,10 @@ module github.com/cyakimov/helios go 1.12 require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/gorilla/mux v1.7.1 github.com/sirupsen/logrus v1.4.1 + golang.org/x/oauth2 v0.0.0-20190402181905-9f3314589c9a golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a // indirect gopkg.in/yaml.v2 v2.2.2 ) diff --git a/go.sum b/go.sum index 81f7bd4..f540548 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,10 @@ +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/gorilla/mux v1.7.1 h1:Dw4jY2nghMMRsh1ol8dv1axHkDwMQK2DHerMNJsIpJU= github.com/gorilla/mux v1.7.1/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= @@ -11,9 +16,19 @@ github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMB github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190402181905-9f3314589c9a h1:tImsplftrFpALCYumobsd0K86vlAs/eXGFms2txfJfA= +golang.org/x/oauth2 v0.0.0-20190402181905-9f3314589c9a/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/main.go b/main.go index 4f33734..95423e1 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "flag" "fmt" + "github.com/cyakimov/helios/authentication" "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -37,6 +38,20 @@ func setupRoutes() *mux.Router { router := mux.NewRouter() upstreams := make(map[string]*http.Handler, len(config.Upstreams)) + oauth2conf := authentication.OAuth2Config{ + ClientID: config.Identity.ClientID, + ClientSecret: config.Identity.ClientSecret, + CallbackURL: config.Identity.OAuth2.CallbackURL, + AuthURL: config.Identity.OAuth2.AuthURL, + TokenURL: config.Identity.OAuth2.TokenURL, + Domain: config.Identity.OAuth2.Domain, + } + + auth0 := authentication.NewAuth0Provider(oauth2conf) + + router.PathPrefix("/.identity/callback").Handler( + authentication.CallbackHandler(auth0, config.JWT.SharedSecret, config.JWT.Expires)) + for _, up := range config.Upstreams { upstreamURL, err := url.Parse(up.URL) if err != nil { @@ -56,14 +71,19 @@ func setupRoutes() *mux.Router { h := router.Host(route.Host) for _, path := range route.HTTP.Paths { - up := upstreams[path.Upstream] + upstream := upstreams[path.Upstream] - if up == nil { + if upstream == nil { log.Fatalf("Upstream %q for route %q not found", path.Upstream, route.Host) break } - h.PathPrefix(path.Path).Handler(*up) + if path.AuthEnabled { + h.PathPrefix(path.Path).Handler(authentication.Middleware(auth0, *upstream)) + } else { + h.PathPrefix(path.Path).Handler(*upstream) + } + } }