Skip to content
This repository has been archived by the owner on Jan 24, 2019. It is now read-only.

Pass UserId as header #276

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ Usage of oauth2_proxy:
-login-url="": Authentication endpoint
-pass-access-token=false: pass OAuth access_token to upstream via X-Forwarded-Access-Token header
-pass-basic-auth=true: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream
-pass-user-id=false, pass user's identifier as X-Forwarded-UserId
-pass-host-header=true: pass the request Host Header to upstream
-profile-url="": Profile access endpoint
-provider="google": OAuth provider
Expand Down
4 changes: 4 additions & 0 deletions contrib/oauth2_proxy.cfg.example
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

## pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream
# pass_basic_auth = true

## Pass user identifier as X-Forwarded-UserId
# pass_user_id = false

## pass the request Host Header to upstream
## when disabled the upstream Host is used as the Host Header
# pass_host_header = true
Expand Down
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func main() {
flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"")
flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path")
flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream")
flagSet.Bool("pass-user-id", false, "pass user's identifier as X-Forwarded-UserId")
flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header")
flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header")
flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream")
Expand Down
13 changes: 12 additions & 1 deletion oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type OAuthProxy struct {
DisplayHtpasswdForm bool
serveMux http.Handler
PassBasicAuth bool
PassUserId bool
SkipProviderButton bool
BasicAuthPassword string
PassAccessToken bool
Expand Down Expand Up @@ -194,11 +195,12 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
skipAuthRegex: opts.SkipAuthRegex,
compiledRegex: opts.CompiledRegex,
PassBasicAuth: opts.PassBasicAuth,
PassUserId: opts.PassUserId,
BasicAuthPassword: opts.BasicAuthPassword,
PassAccessToken: opts.PassAccessToken,
SkipProviderButton: opts.SkipProviderButton,
CookieCipher: cipher,
templates: loadTemplates(opts.CustomTemplatesDir),
SkipProviderButton: opts.SkipProviderButton,
Footer: opts.Footer,
}
}
Expand Down Expand Up @@ -238,6 +240,9 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(s)
}

s.UserId, err = p.provider.GetUserId(s)

return
}

Expand Down Expand Up @@ -602,6 +607,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
req.Header["X-Forwarded-Email"] = []string{session.Email}
}
}

if p.PassUserId {
log.Printf("Passing UserId header %s", session.UserId)
req.Header["X-Forwarded-UserId"] = []string{session.UserId}
}

if p.PassAccessToken && session.AccessToken != "" {
req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken}
}
Expand Down
1 change: 1 addition & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Options struct {
Upstreams []string `flag:"upstream" cfg:"upstreams"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"`
PassUserId bool `flag:"pass-user-id" cfg:"pass_user_id"`
BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password"`
PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token"`
PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"`
Expand Down
43 changes: 43 additions & 0 deletions providers/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log"
"net/http"
"net/url"
"strconv"
"path"
"strings"
)
Expand Down Expand Up @@ -234,3 +235,45 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {

return "", nil
}

func (p *GitHubProvider) GetUserId(s *SessionState) (string, error) {

var userData struct {
UserId int `json:"id"`
}

params := url.Values{
"access_token": {s.AccessToken},
}

endpoint := &url.URL{
Scheme: p.ValidateURL.Scheme,
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user"),
RawQuery: params.Encode(),
}
resp, err := http.DefaultClient.Get(endpoint.String())

if err != nil {
return "", err
}
body, err := ioutil.ReadAll(resp.Body)

if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body)
} else {
log.Printf("got %d from %q %s", resp.StatusCode, endpoint, body)
}

if err := json.Unmarshal(body, &userData); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}

var id = strconv.Itoa(userData.UserId)

log.Printf("User ID is", id)

return id, nil

}

4 changes: 4 additions & 0 deletions providers/provider_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,7 @@ func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
return false, nil
}

func (p *ProviderData) GetUserId(s *SessionState) (string, error) {
return "", nil
}
1 change: 1 addition & 0 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
type Provider interface {
Data() *ProviderData
GetEmailAddress(*SessionState) (string, error)
GetUserId(*SessionState) (string, error)
Redeem(string, string) (*SessionState, error)
ValidateGroup(string) bool
ValidateSessionState(*SessionState) bool
Expand Down
25 changes: 21 additions & 4 deletions providers/session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"strconv"
"strings"
"time"

"github.com/bitly/oauth2_proxy/cookie"
)

Expand All @@ -15,6 +14,7 @@ type SessionState struct {
RefreshToken string
Email string
User string
UserId string
}

func (s *SessionState) IsExpired() bool {
Expand Down Expand Up @@ -72,7 +72,16 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
return "", err
}
}
return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil

u := s.UserId
if u != "" {
u, err = c.Encrypt(u)
if err != nil {
return "", err
}
}

return fmt.Sprintf("%s|%s|%d|%s|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r, u), nil
}

func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
Expand All @@ -85,8 +94,8 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error)
return &SessionState{User: v}, nil
}

if len(chunks) != 4 {
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
if len(chunks) != 5 {
err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks))
return
}

Expand All @@ -103,6 +112,14 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error)
return nil, err
}
}

if c!=nil && chunks[4] != "" {
s.UserId, err = c.Decrypt(chunks[4])
if err !=nil {
return nil, err
}
}

if u := chunks[0]; strings.Contains(u, "@") {
s.Email = u
s.User = strings.Split(u, "@")[0]
Expand Down
4 changes: 3 additions & 1 deletion providers/session_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ func TestSessionStateSerialization(t *testing.T) {
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
UserId: "1",

}
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
assert.Equal(t, 3, strings.Count(encoded, "|"))
assert.Equal(t, 4, strings.Count(encoded, "|"))

ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss)
Expand Down