diff --git a/README.md b/README.md index 3c0f98fe9..563708e50 100644 --- a/README.md +++ b/README.md @@ -185,7 +185,7 @@ See below for provider specific options ### Upstreams Configuration -`oauth2_proxy` supports having multiple upstreams, and has the option to pass requests on to HTTP(S) servers or serve static files from the file system. HTTP and HTTPS upstreams are configured by providing a URL such as `http://127.0.0.1:8080/` for the upstream parameter, that will forward all authenticated requests to be forwarded to the upstream server. If you instead provide `http://127.0.0.1:8080/some/path/` then it will only be requests that start with `/some/path/` which are forwarded to the upstream. +`oauth2_proxy` supports having multiple upstreams, and has the option to pass requests on to HTTP(S) servers or serve static files from the file system. HTTP and HTTPS upstreams are configured by providing a URL such as `http://127.0.0.1:8080/` for the upstream parameter, that will forward all authenticated requests to be forwarded to the upstream server. If you instead provide `http://127.0.0.1:8080/some/path/` then it will only be requests that start with `/some/path/` which are forwarded to the upstream. Websocket requests are proxied transparently to HTTP and HTTPS upstreams. Static file paths are configured as a file:// URL. `file:///var/www/static/` will serve the files from that directory at `http://[oauth2_proxy url]/var/www/static/`, which may not be what you want. You can provide the path to where the files should be available by adding a fragment to the configured URL. The value of the fragment will then be used to specify which path the files are available at. `file:///var/www/static/#/static/` will ie. make `/var/www/static/` available at `http://[oauth2_proxy url]/static/`. diff --git a/logging_handler.go b/logging_handler.go index 17fca977b..c36048189 100644 --- a/logging_handler.go +++ b/logging_handler.go @@ -4,6 +4,8 @@ package main import ( + "bufio" + "errors" "fmt" "io" "net" @@ -26,6 +28,16 @@ func (l *responseLogger) Header() http.Header { return l.w.Header() } +func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := l.w.(http.Hijacker) + + if !ok { + return nil, nil, errors.New("webserver doesn't support hijacking") + } + + return hijacker.Hijack() +} + func (l *responseLogger) ExtractGAPMetadata() { upstream := l.w.Header().Get("GAP-Upstream-Address") if upstream != "" { diff --git a/oauthproxy.go b/oauthproxy.go index 16adf2249..e5c1f0c89 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -68,18 +68,22 @@ type OAuthProxy struct { } type UpstreamProxy struct { - upstream string + upstream url.URL handler http.Handler auth hmacauth.HmacAuth } func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Header().Set("GAP-Upstream-Address", u.upstream) + w.Header().Set("GAP-Upstream-Address", u.upstream.Host) if u.auth != nil { r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) u.auth.SignRequest(r) } - u.handler.ServeHTTP(w, r) + if isWebsocketRequest(r) { + u.handleWebsocket(w, r) + } else { + u.handler.ServeHTTP(w, r) + } } func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) { @@ -128,14 +132,14 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { setProxyDirector(proxy) } serveMux.Handle(path, - &UpstreamProxy{u.Host, proxy, auth}) + &UpstreamProxy{*u, proxy, auth}) case "file": if u.Fragment != "" { path = u.Fragment } log.Printf("mapping path %q => file system %q", path, u.Path) proxy := NewFileServer(path, u.Path) - serveMux.Handle(path, &UpstreamProxy{path, proxy, nil}) + serveMux.Handle(path, &UpstreamProxy{*u, proxy, nil}) default: panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme)) } diff --git a/websocket.go b/websocket.go new file mode 100644 index 000000000..21995648d --- /dev/null +++ b/websocket.go @@ -0,0 +1,138 @@ +package main + +import ( + "io" + "log" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +var ( + ConnectionHeaderKey = http.CanonicalHeaderKey("connection") + SetCookieHeaderKey = http.CanonicalHeaderKey("set-cookie") + UpgradeHeaderKey = http.CanonicalHeaderKey("upgrade") + WSKeyHeaderKey = http.CanonicalHeaderKey("sec-websocket-key") + WSProtocolHeaderKey = http.CanonicalHeaderKey("sec-websocket-protocol") + WSVersionHeaderKey = http.CanonicalHeaderKey("sec-websocket-version") + + ConnectionHeaderValue = "Upgrade" + UpgradeHeaderValue = "websocket" + + HandshakeHeaders = []string{ConnectionHeaderKey, UpgradeHeaderKey, WSVersionHeaderKey, WSKeyHeaderKey} + UpgradeHeaders = []string{SetCookieHeaderKey, WSProtocolHeaderKey} +) + +func (u *UpstreamProxy) handleWebsocket(w http.ResponseWriter, r *http.Request) { + + // Copy request headers and remove websocket handshaking headers + // before submitting to the upstream server + upstreamHeader := http.Header{} + for key, _ := range r.Header { + copyHeader(&upstreamHeader, r.Header, key) + } + for _, header := range HandshakeHeaders { + delete(upstreamHeader, header) + } + upstreamHeader.Set("Host", r.Host) + + // Connect upstream + upstreamAddr := u.upstreamWSURL(*r.URL).String() + upstream, upstreamResp, err := websocket.DefaultDialer.Dial(upstreamAddr, upstreamHeader) + if err != nil { + if upstreamResp != nil { + log.Printf("dialing upstream websocket failed with code %d: %v", upstreamResp.StatusCode, err) + } else { + log.Printf("dialing upstream websocket failed: %v", err) + } + http.Error(w, "websocket unavailable", http.StatusServiceUnavailable) + return + } + defer upstream.Close() + + // Pass websocket handshake response headers to the upgrader + upgradeHeader := http.Header{} + copyHeaders(&upgradeHeader, upstreamResp.Header, UpgradeHeaders) + + // Upgrade the client connection without validating the origin + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + client, err := upgrader.Upgrade(w, r, upgradeHeader) + if err != nil { + log.Printf("couldn't upgrade websocket request: %v", err) + http.Error(w, "websocket upgrade failed", http.StatusServiceUnavailable) + return + } + + // Wire both sides together and close when finished + var wg sync.WaitGroup + cp := func(dst, src *websocket.Conn) { + defer wg.Done() + _, err := io.Copy(dst.UnderlyingConn(), src.UnderlyingConn()) + + var closeMessage []byte + if err != nil { + closeMessage = websocket.FormatCloseMessage(websocket.CloseProtocolError, err.Error()) + } else { + closeMessage = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye") + } + // Attempt to close the connection properly + dst.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(2*time.Second)) + src.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(2*time.Second)) + } + wg.Add(2) + go cp(upstream, client) + go cp(client, upstream) + wg.Wait() +} + +// Create a websocket URL from the request URL +func (u *UpstreamProxy) upstreamWSURL(r url.URL) *url.URL { + ws := r + ws.User = r.User + ws.Host = u.upstream.Host + ws.Fragment = "" + switch u.upstream.Scheme { + case "http": + ws.Scheme = "ws" + case "https": + ws.Scheme = "wss" + } + return &ws +} + +func isWebsocketRequest(req *http.Request) bool { + return isHeaderValuePresent(req.Header, UpgradeHeaderKey, UpgradeHeaderValue) && + isHeaderValuePresent(req.Header, ConnectionHeaderKey, ConnectionHeaderValue) +} + +func isHeaderValuePresent(headers http.Header, key string, value string) bool { + for _, header := range headers[key] { + for _, v := range strings.Split(header, ",") { + if strings.EqualFold(value, strings.TrimSpace(v)) { + return true + } + } + } + return false +} + +func copyHeaders(dst *http.Header, src http.Header, headers []string) { + for _, header := range headers { + copyHeader(dst, src, header) + } +} + +// Copy any non-empty and non-blank header values +func copyHeader(dst *http.Header, src http.Header, header string) { + for _, value := range src[header] { + if value != "" { + dst.Add(header, value) + } + } +} diff --git a/websocket_test.go b/websocket_test.go new file mode 100644 index 000000000..158b32a26 --- /dev/null +++ b/websocket_test.go @@ -0,0 +1,47 @@ +package main + +import ( + "net/http" + "testing" + + "github.com/bmizerany/assert" +) + +func TestCopyHeader(t *testing.T) { + src := http.Header{ + "EmptyValue": []string{""}, + "Nil": []string{}, + "Single": []string{"one"}, + "Multi": []string{"one", "two"}, + } + expected := http.Header{ + "Single": []string{"one"}, + "Multi": []string{"one", "two"}, + } + dst := http.Header{} + for key, _ := range src { + copyHeader(&dst, src, key) + } + assert.Equal(t, expected, dst) +} + +func TestUpgrade(t *testing.T) { + tests := []struct { + upgrade bool + connectionValue string + upgradeValue string + }{ + {true, "Upgrade", "Websocket"}, + {true, "keepalive, Upgrade", "websocket"}, + {false, "", "websocket"}, + {false, "keepalive, Upgrade", ""}, + } + + for _, tt := range tests { + req := new(http.Request) + req.Header = http.Header{} + req.Header.Set(ConnectionHeaderKey, tt.connectionValue) + req.Header.Set(UpgradeHeaderKey, tt.upgradeValue) + assert.Equal(t, tt.upgrade, isWebsocketRequest(req)) + } +}