Skip to content
This repository has been archived by the owner on Jan 24, 2019. It is now read-only.

1) Add websocket support, and 2) make sure to redirect to where we came from #64

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions gap.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
## Google Auth Proxy Config File
## https://github.com/bitly/google_auth_proxy

## <addr>:<port> to listen on for HTTP clients
# http_address = "127.0.0.1:4180"

## the OAuth Redirect URL.
redirect_url = "https://auth.int.treatwell.com/oauth2/callback"

## the http url(s) of the upstream endpoint. If multiple, routing is based on path
upstreams = [
"http://127.0.0.1:8080/"
]

## pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream
# pass_basic_auth = true

## Google Apps Domains to allow authentication for
google_apps_domains = [
"treatwell.nl"
]


## The Google OAuth Client ID, Secret
client_id = "772435130105-ppfk945bgmv2oejestd936cpkqgh2h6p.apps.googleusercontent.com"
client_secret = "G_fMH57Mp6gOh7M1XJyY5oie"

## Authenticated Email Addresses File (one email per line)
# authenticated_emails_file = ""

## Htpasswd File (optional)
## Additionally authenticate against a htpasswd file. Entries must be created with "htpasswd -s" for SHA encryption
## enabling exposes a username/login signin form
#htpasswd_file = "/opt/htpasswd"


## Cookie Settings
## Secret - the seed string for secure cookies
## Domain - optional cookie domain to force cookies to (ie: .yourcompany.com)
## Expire - expire timeframe for cookie
# cookie_secret = ""
cookie_domain = "int.treatwell.com"
cookie_expire = "168h"
# cookie_https_only = true
# cookie_httponly = true

40 changes: 26 additions & 14 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io/ioutil"
"log"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strings"
Expand Down Expand Up @@ -54,7 +53,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
path := u.Path
u.Path = ""
log.Printf("mapping path %q => upstream %q", path, u)
serveMux.Handle(path, httputil.NewSingleHostReverseProxy(u))
serveMux.Handle(path, NewWebsocketReverseProxy(u))
}
for _, u := range opts.CompiledRegex {
log.Printf("compiled skip-auth-regex => %q", u)
Expand Down Expand Up @@ -98,9 +97,7 @@ func (p *OauthProxy) GetLoginURL(redirectUrl string) string {
params.Add("scope", p.oauthScope)
params.Add("client_id", p.clientID)
params.Add("response_type", "code")
if strings.HasPrefix(redirectUrl, "/") {
params.Add("state", redirectUrl)
}
params.Add("state", redirectUrl)
return fmt.Sprintf("%s?%s", p.oauthLoginUrl, params.Encode())
}

Expand Down Expand Up @@ -227,16 +224,18 @@ func (p *OauthProxy) PingPage(rw http.ResponseWriter) {
fmt.Fprintf(rw, "OK")
}

func (p *OauthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) {
func (p *OauthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) {
log.Printf("ErrorPage %d %s %s", code, title, message)
rw.WriteHeader(code)
templates := getTemplates()
t := struct {
Title string
Message string
Redirect string
}{
Title: fmt.Sprintf("%d %s", code, title),
Message: message,
Redirect: req.Form.Get("state"),
}
templates.ExecuteTemplate(rw, "error.html", t)
}
Expand All @@ -246,6 +245,11 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
rw.WriteHeader(code)
templates := getTemplates()

redirect := req.FormValue("rd")
if redirect == "" {
redirect = fmt.Sprintf("https://%s%s", req.Host, req.URL.RequestURI())
}

t := struct {
SignInMessage string
CustomLogin bool
Expand All @@ -254,9 +258,10 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
}{
SignInMessage: p.SignInMessage,
CustomLogin: p.displayCustomLoginForm(),
Redirect: req.URL.RequestURI(),
Redirect: redirect,
Version: VERSION,
}

templates.ExecuteTemplate(rw, "sign_in.html", t)
}

Expand Down Expand Up @@ -322,7 +327,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == signInPath {
redirect, err := p.GetRedirect(req)
if err != nil {
p.ErrorPage(rw, 500, "Internal Error", err.Error())
p.ErrorPage(rw, req, 500, "Internal Error", err.Error())
return
}

Expand All @@ -338,7 +343,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == oauthStartPath {
redirect, err := p.GetRedirect(req)
if err != nil {
p.ErrorPage(rw, 500, "Internal Error", err.Error())
p.ErrorPage(rw, req, 500, "Internal Error", err.Error())
return
}
http.Redirect(rw, req, p.GetLoginURL(redirect), 302)
Expand All @@ -348,19 +353,19 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// finish the oauth cycle
err := req.ParseForm()
if err != nil {
p.ErrorPage(rw, 500, "Internal Error", err.Error())
p.ErrorPage(rw, req, 500, "Internal Error", err.Error())
return
}
errorString := req.Form.Get("error")
if errorString != "" {
p.ErrorPage(rw, 403, "Permission Denied", errorString)
p.ErrorPage(rw, req, 403, "Permission Denied", errorString)
return
}

_, email, err := p.redeemCode(req.Form.Get("code"))
if err != nil {
log.Printf("%s error redeeming code %s", remoteAddr, err)
p.ErrorPage(rw, 500, "Internal Error", err.Error())
p.ErrorPage(rw, req, 500, "Internal Error", err.Error())
return
}

Expand All @@ -376,7 +381,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
http.Redirect(rw, req, redirect, 302)
return
} else {
p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account")
p.ErrorPage(rw, req, 403, "Permission Denied", "Invalid Account")
return
}
}
Expand All @@ -389,8 +394,11 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}

authedByBasicAuth := false

if !ok {
user, ok = p.CheckBasicAuth(req)
authedByBasicAuth = ok
// if we want to promote basic auth requests to cookie'd requests, we could do that here
// not sure that would be ideal in all circumstances though
// if ok {
Expand All @@ -406,7 +414,11 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {

// At this point, the user is authenticated. proxy normally
if p.PassBasicAuth {
req.SetBasicAuth(user, "")
if (authedByBasicAuth) {
// Strip the password if the basic auth was used to identify the google_auth_proxy user
// otherwise, just pass the basic auth information along.
req.SetBasicAuth(user, "")
}
req.Header["X-Forwarded-User"] = []string{user}
req.Header["X-Forwarded-Email"] = []string{email}
}
Expand Down
2 changes: 1 addition & 1 deletion templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func getTemplates() *template.Template {
<h2>{{.Title}}</h2>
<p>{{.Message}}</p>
<hr>
<p><a href="/oauth2/sign_in">Sign In</a></p>
<p><a href="/oauth2/sign_in?rd={{.Redirect}}">Sign In</a></p>
</body>
</html>{{end}}`)
if err != nil {
Expand Down
117 changes: 117 additions & 0 deletions websocket_reverse_proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package main

import (
"bufio"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"sync"
)

type WebsocketReverseProxy struct {
Proxy *httputil.ReverseProxy
Upstream string
}

func NewWebsocketReverseProxy(target *url.URL) *WebsocketReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(target)
return &WebsocketReverseProxy{Proxy: proxy, Upstream: target.Host}
}

func (p *WebsocketReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if websocketUpgradeRequest(req) {
p.hijackWebsocket(rw, req)
} else {
p.Proxy.ServeHTTP(rw, req)
}
}

func (p *WebsocketReverseProxy) hijackWebsocket(rw http.ResponseWriter, req *http.Request) {
highjacker, ok := rw.(http.Hijacker)

if !ok {
http.Error(rw, "webserver doesn't support hijacking", http.StatusInternalServerError)
return
}

conn, bufrw, err := highjacker.Hijack()
defer conn.Close()

conn2, err := net.Dial("tcp", p.Upstream)
if err != nil {
log.Printf("couldn't connect to backend websocket server: %v", err)
http.Error(rw, "couldn't connect to backend server", http.StatusServiceUnavailable)
return
}
defer conn2.Close()

err = req.Write(conn2)
if err != nil {
log.Printf("writing WebSocket request to backend server failed: %v", err)
return
}

bufferedBidirCopy(conn, bufrw, conn2, bufio.NewReadWriter(bufio.NewReader(conn2), bufio.NewWriter(conn2)))
}

func websocketUpgradeRequest(req *http.Request) bool {
connection_headers, ok := req.Header["Connection"]
if !ok || len(connection_headers) <= 0 {
return false
}

connection_header := connection_headers[0]
if strings.ToLower(connection_header) != "upgrade" {
return false
}

upgrade_headers, ok := req.Header["Upgrade"]
if !ok || len(upgrade_headers) <= 0 {
return false
}

return strings.ToLower(upgrade_headers[0]) == "websocket"
}

func bufferedCopy(dest *bufio.ReadWriter, src *bufio.ReadWriter) {
buf := make([]byte, 40*1024)
for {
n, err := src.Read(buf)
if err != nil && err != io.EOF {
log.Printf("Upstream read failed: %v", err)
return
}
if n == 0 {
return
}
n, err = dest.Write(buf[0:n])
if err != nil && err != io.EOF {
log.Printf("Downstream write failed: %v", err)
return
}

err = dest.Flush()
if err != nil {
log.Printf("Downstream write flush failed: %v", err)
return
}
}
}

func bufferedBidirCopy(conn1 io.ReadWriteCloser, rw1 *bufio.ReadWriter, conn2 io.ReadWriteCloser, rw2 *bufio.ReadWriter) {
wg := sync.WaitGroup{}

copier := func(wg *sync.WaitGroup, rw1 *bufio.ReadWriter, rw2 *bufio.ReadWriter) {
defer wg.Done()
bufferedCopy(rw2, rw1)
}

wg.Add(2)
go copier(&wg, rw1, rw2)
go copier(&wg, rw2, rw1)
wg.Wait()
}