Skip to content
This repository has been archived by the owner on May 2, 2023. It is now read-only.

Commit

Permalink
Add request authentication logic
Browse files Browse the repository at this point in the history
  • Loading branch information
cyakimov committed Apr 16, 2019
1 parent ffa1903 commit d6023ad
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
*.out

*.pem
config.yaml
73 changes: 73 additions & 0 deletions authentication/auth0.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package authentication

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"net/http"
)

type Auth0Provider struct {
OAuth2Provider
oauth2 oauth2.Config
domain string
}

func NewAuth0Provider(config OAuth2Config) OAuth2Provider {
return Auth0Provider{
domain: config.Domain,
oauth2: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.CallbackURL,
Scopes: []string{"openid", "email_verified", "email"},
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
},
}
}

func (provider Auth0Provider) GetUserProfile(r *http.Request) (OIDCProfile, error) {
var profile OIDCProfile
code := r.URL.Query().Get("code")

token, err := provider.oauth2.Exchange(context.TODO(), code)
if err != nil {
return profile, ErrCodeExchange
}

// Get user profile
client := provider.oauth2.Client(context.TODO(), token)
resp, err := client.Get(provider.domain + "/userinfo")
if err != nil {
return profile, ErrProfile
}

defer func() {
if resp.Body != nil {
if err := resp.Body.Close(); err != nil {
log.Error(err)
}
}
}()

if err = json.NewDecoder(resp.Body).Decode(&profile); err != nil {
return profile, errors.New("auth0: cannot decode JSON profile")
}

if profile.Email == "" {
return profile, ErrNoEmail
}

return profile, nil
}

func (provider Auth0Provider) GetLoginURL(state string) string {
s := base64.StdEncoding.EncodeToString([]byte(state))
return provider.oauth2.AuthCodeURL(s)
}
97 changes: 97 additions & 0 deletions authentication/identity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package authentication

import (
"encoding/base64"
"errors"
log "github.com/sirupsen/logrus"
"net/http"
"time"
)

type OAuth2Config struct {
ClientID string
ClientSecret string
CallbackURL string
AuthURL string
TokenURL string
Domain string
}

type OIDCProfile struct {
Email string
}

type OAuth2Provider interface {
GetUserProfile(r *http.Request) (OIDCProfile, error)
GetLoginURL(state string) string
}

const CookieName = "Helios_Authorization"
const HeaderName = "Helios-Jwt-Assertion"

var ErrUnauthorized = errors.New("unauthorized request")
var ErrCodeExchange = errors.New("error on code exchange")
var ErrProfile = errors.New("error getting user profile")
var ErrNoEmail = errors.New("no email found in user profile")

func authenticate(r *http.Request) error {
// look for Token in both Cookies and Headers
_, err := r.Cookie(CookieName)
htoken := r.Header.Get(HeaderName)

if err == http.ErrNoCookie && htoken == "" {
return ErrUnauthorized
}

return nil
}

func Middleware(provider OAuth2Provider, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := authenticate(r); err != nil {
url := provider.GetLoginURL(r.RequestURI)
log.Debugf("Redirecting to %s", url)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
return
}

log.Println(r.RequestURI)
// Call the next handler, which can be another middleware in the chain, or the final handler.
next.ServeHTTP(w, r)
})
}

func CallbackHandler(provider OAuth2Provider, jwtSecret string, jwtDuration time.Duration) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
state := r.URL.Query().Get("state")

dstate, err := base64.StdEncoding.DecodeString(state)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}

profile, err := provider.GetUserProfile(r)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}

log.Debugf("Authorized. Redirecting to %s", string(dstate))

exp := time.Now().Add(jwtDuration)
jwt, err := IssueJWT(jwtSecret, profile.Email, exp)
http.SetCookie(w, &http.Cookie{
Name: CookieName,
Value: jwt,
Expires: exp,
Path: "/",
Secure: true,
})

http.Redirect(w, r, string(dstate), http.StatusFound)
return
})
}
21 changes: 21 additions & 0 deletions authentication/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package authentication

import (
"github.com/dgrijalva/jwt-go"
"time"
)

func IssueJWT(secret string, email string, expires time.Time) (string, error) {
key := []byte(secret)

// Create the Claims
claims := &jwt.StandardClaims{
ExpiresAt: expires.Unix(),
Subject: email,
Issuer: "Helios",
IssuedAt: time.Now().Unix(),
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(key)
}
17 changes: 17 additions & 0 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,20 @@ routes:
paths:
- path: /
upstream: httpbin
- path: /json
upstream: httpbin
auth_enabled: false

identity:
provider: auth0
client_id: long-hash-here
client_secret: long-hash-here
oauth2:
domain: https://yourtenant.auth0.com
callback: https://localhost/callback
auth_url: https://yourtenant.auth0.com/authorize
token_url: https://yourtenant.auth0.com/oauth/token

jwt:
shared_secret: replace-this-with-a-long-hash
expires: 10h
55 changes: 48 additions & 7 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
)

type Config struct {
Server ServerConfig `yaml:"server"`
Upstreams []Upstream `yaml:"upstreams"`
Routes []Route `yaml:"routes"`
Server Server `yaml:"server"`
Upstreams []Upstream `yaml:"upstreams"`
Routes []Route `yaml:"routes"`
Identity Identity `yaml:"identity"`
JWT JWT `yaml:"jwt"`
}

type ServerConfig struct {
type Server struct {
ListenIP string `yaml:"listen_ip"`
ListenPort int `yaml:"listen_port"`
Timeout time.Duration `yaml:"timeout"`
Expand All @@ -27,8 +29,9 @@ type Route struct {
Host string
HTTP struct {
Paths []struct {
Path string
Upstream string
Path string
Upstream string
AuthEnabled bool `yaml:"auth_enabled"`
}
}
}
Expand All @@ -39,6 +42,23 @@ type Upstream struct {
ConnectTimeout time.Duration
}

type Identity struct {
Provider string `yaml:"provider"`
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
OAuth2 struct {
CallbackURL string `yaml:"callback_url"`
AuthURL string `yaml:"auth_url"`
TokenURL string `yaml:"token_url"`
Domain string `yaml:"domain"`
}
}

type JWT struct {
SharedSecret string
Expires time.Duration
}

func (c *Upstream) UnmarshalYAML(unmarshal func(v interface{}) error) error {
buf := struct {
ConnectTimeout string `yaml:"connect_timeout"`
Expand All @@ -62,7 +82,28 @@ func (c *Upstream) UnmarshalYAML(unmarshal func(v interface{}) error) error {
return nil
}

func (c *ServerConfig) UnmarshalYAML(unmarshal func(v interface{}) error) error {
func (c *JWT) UnmarshalYAML(unmarshal func(v interface{}) error) error {
buf := struct {
SharedSecret string `yaml:"shared_secret"`
Expires string `yaml:"expires"`
}{}

if err := unmarshal(&buf); err != nil {
return err
}

expires, err := time.ParseDuration(buf.Expires)
if err != nil {
return err
}

c.Expires = expires
c.SharedSecret = buf.SharedSecret

return nil
}

func (c *Server) UnmarshalYAML(unmarshal func(v interface{}) error) error {
var buf struct {
ListenIP string `yaml:"listen_ip"`
ListenPort int `yaml:"listen_port"`
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ module github.com/cyakimov/helios
go 1.12

require (
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/gorilla/mux v1.7.1
github.com/sirupsen/logrus v1.4.1
golang.org/x/oauth2 v0.0.0-20190402181905-9f3314589c9a
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a // indirect
gopkg.in/yaml.v2 v2.2.2
)
15 changes: 15 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/gorilla/mux v1.7.1 h1:Dw4jY2nghMMRsh1ol8dv1axHkDwMQK2DHerMNJsIpJU=
github.com/gorilla/mux v1.7.1/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
Expand All @@ -11,9 +16,19 @@ github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMB
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190402181905-9f3314589c9a h1:tImsplftrFpALCYumobsd0K86vlAs/eXGFms2txfJfA=
golang.org/x/oauth2 v0.0.0-20190402181905-9f3314589c9a/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
Expand Down
26 changes: 23 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"flag"
"fmt"
"github.com/cyakimov/helios/authentication"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
Expand Down Expand Up @@ -37,6 +38,20 @@ func setupRoutes() *mux.Router {
router := mux.NewRouter()
upstreams := make(map[string]*http.Handler, len(config.Upstreams))

oauth2conf := authentication.OAuth2Config{
ClientID: config.Identity.ClientID,
ClientSecret: config.Identity.ClientSecret,
CallbackURL: config.Identity.OAuth2.CallbackURL,
AuthURL: config.Identity.OAuth2.AuthURL,
TokenURL: config.Identity.OAuth2.TokenURL,
Domain: config.Identity.OAuth2.Domain,
}

auth0 := authentication.NewAuth0Provider(oauth2conf)

router.PathPrefix("/.identity/callback").Handler(
authentication.CallbackHandler(auth0, config.JWT.SharedSecret, config.JWT.Expires))

for _, up := range config.Upstreams {
upstreamURL, err := url.Parse(up.URL)
if err != nil {
Expand All @@ -56,14 +71,19 @@ func setupRoutes() *mux.Router {
h := router.Host(route.Host)

for _, path := range route.HTTP.Paths {
up := upstreams[path.Upstream]
upstream := upstreams[path.Upstream]

if up == nil {
if upstream == nil {
log.Fatalf("Upstream %q for route %q not found", path.Upstream, route.Host)
break
}

h.PathPrefix(path.Path).Handler(*up)
if path.AuthEnabled {
h.PathPrefix(path.Path).Handler(authentication.Middleware(auth0, *upstream))
} else {
h.PathPrefix(path.Path).Handler(*upstream)
}

}
}

Expand Down

0 comments on commit d6023ad

Please sign in to comment.