From 09f0599a6ecfec2652e01163283ee6b5a79a4c4f Mon Sep 17 00:00:00 2001 From: Gurvinder Singh Date: Thu, 2 Mar 2017 21:44:57 +0100 Subject: [PATCH] added support for websocket proxying. the support is mainly based on PR from soellman on bitly oauth2_proxy https://github.com/bitly/oauth2_proxy/pull/201 --- glide.lock | 14 ++++- glide.yaml | 3 +- main.go | 4 +- proxy.go | 27 +++++++-- websocket.go | 139 ++++++++++++++++++++++++++++++++++++++++++++++ websocket_test.go | 47 ++++++++++++++++ 6 files changed, 224 insertions(+), 10 deletions(-) create mode 100644 websocket.go create mode 100644 websocket_test.go diff --git a/glide.lock b/glide.lock index eefed7b..3740611 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: cc9d72a7b674a370888b1eddd771f86572c2ee4cbece85590fa8147bb0072b0b -updated: 2017-03-02T09:29:36.314575617+01:00 +hash: 029d04ba28feaea320a8b025603504b860f6f7af2fa57904af73989757c547aa +updated: 2017-03-02T19:36:55.571492373+01:00 imports: - name: github.com/coreos/go-oidc version: f828b1fc9b58b59bd70ace766bfc190216b58b01 @@ -13,6 +13,8 @@ imports: version: 69b215d01a5606c843240eab4937eab3acee6530 subpackages: - proto +- name: github.com/gorilla/websocket + version: 4873052237e4eeda85cf50c071ef33836fe8e139 - name: github.com/hashicorp/hcl version: 630949a3c5fa3c613328e1b8256052cbc2327c9b subpackages: @@ -37,7 +39,7 @@ imports: - name: github.com/pelletier/go-buffruneio version: c37440a7cf42ac63b919c752ca73a85067e05992 - name: github.com/pelletier/go-toml - version: 361678322880708ac144df8575e6f01144ba1404 + version: 13d49d4606eb801b8f01ae542b4afc4c6ee3d84a - name: github.com/pkg/errors version: bfd5150e4e41705ded2129ec33379de1cb90b513 - name: github.com/pquerna/cachecontrol @@ -101,6 +103,12 @@ imports: - name: gopkg.in/yaml.v2 version: a3f3340b5840cee44f372bddb5880fcbc419b46a testImports: +- name: github.com/bmizerany/assert + version: b7ed37b82869576c289d7d97fb2bbd8b64a0cb28 +- name: github.com/kr/pretty + version: cfb55aafdaf3ec08f0db22699ab822c50091b1c4 +- name: github.com/kr/text + version: 7cafcd837844e784b526369c9bce262804aebc60 - name: github.com/pmezard/go-difflib version: d8ed2627bdf02c080bf22230dbb337003b7aba2d subpackages: diff --git a/glide.yaml b/glide.yaml index d8faf8f..caeaea5 100644 --- a/glide.yaml +++ b/glide.yaml @@ -7,4 +7,5 @@ import: - package: golang.org/x/oauth2 - package: github.com/m4rw3r/uuid - package: github.com/parnurzeal/gorequest -- package: github.com/SermoDigital/jose/jws \ No newline at end of file +- package: github.com/SermoDigital/jose/jws +- package: github.com/gorilla/websocket \ No newline at end of file diff --git a/main.go b/main.go index 2ca41be..bf65453 100644 --- a/main.go +++ b/main.go @@ -92,7 +92,7 @@ func main() { } // Create proxy and middleware - target := NewReverseProxy(targetURL) + upstream := NewUpstreamProxy(targetURL) authn, err := newAuthenticator( conf.GetStringValue("engine.client_id"), conf.GetStringValue("engine.client_secret"), @@ -106,7 +106,7 @@ func main() { // Configure routes http.Handle("/healthz", healthzHandler(targetURL.String())) http.Handle("/oauth2/callback", authn.callbackHandler()) - http.Handle("/", authn.authHandler(target)) + http.Handle("/", authn.authHandler(upstream)) // Start proxying log.Println("Proxy initialized and listening on port", conf.GetIntValue("server.port")) diff --git a/proxy.go b/proxy.go index bd61467..c781b16 100644 --- a/proxy.go +++ b/proxy.go @@ -3,9 +3,6 @@ package main import ( "bytes" "encoding/json" - log "github.com/Sirupsen/logrus" - "github.com/uninett/goidc-proxy/conf" - "golang.org/x/oauth2" "io/ioutil" "net" "net/http" @@ -13,6 +10,10 @@ import ( "net/url" "strings" "time" + + log "github.com/Sirupsen/logrus" + "github.com/uninett/goidc-proxy/conf" + "golang.org/x/oauth2" ) type transport struct { @@ -23,6 +24,11 @@ type ACRValues struct { Values string `json:"required_acr_values"` } +type UpstreamProxy struct { + upstream *url.URL + handler http.Handler +} + func singleJoiningSlash(a, b string) string { aslash := strings.HasSuffix(a, "/") bslash := strings.HasPrefix(b, "/") @@ -74,8 +80,21 @@ func (t *transport) RoundTrip(req *http.Request) (resp *http.Response, err error return resp, nil } +func NewUpstreamProxy(target *url.URL) *UpstreamProxy { + proxy := newReverseProxy(target) + return &UpstreamProxy{target, proxy} +} + +func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if isWebsocketRequest(r) { + u.handleWebsocket(w, r) + } else { + u.handler.ServeHTTP(w, r) + } +} + // NewReverseProxy prvoides reverse proxy functionality towards target -func NewReverseProxy(target *url.URL) *httputil.ReverseProxy { +func newReverseProxy(target *url.URL) *httputil.ReverseProxy { director := func(req *http.Request) { req.URL.Scheme = target.Scheme req.URL.Host = target.Host diff --git a/websocket.go b/websocket.go new file mode 100644 index 0000000..d13f9cf --- /dev/null +++ b/websocket.go @@ -0,0 +1,139 @@ +package main + +import ( + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + log "github.com/Sirupsen/logrus" + "github.com/gorilla/websocket" +) + +var ( + ConnectionHeaderKey = http.CanonicalHeaderKey("connection") + SetCookieHeaderKey = http.CanonicalHeaderKey("set-cookie") + UpgradeHeaderKey = http.CanonicalHeaderKey("upgrade") + WSKeyHeaderKey = http.CanonicalHeaderKey("sec-websocket-key") + WSVersionHeaderKey = http.CanonicalHeaderKey("sec-websocket-version") + WSProtocolHeaderKey = http.CanonicalHeaderKey("sec-websocket-protocol") + WSExtensionHeaderKey = http.CanonicalHeaderKey("sec-websocket-extensions") + + ConnectionHeaderValue = "Upgrade" + UpgradeHeaderValue = "websocket" + + HandshakeHeaders = []string{ConnectionHeaderKey, UpgradeHeaderKey, WSVersionHeaderKey, WSKeyHeaderKey, WSExtensionHeaderKey} + UpgradeHeaders = []string{SetCookieHeaderKey, WSProtocolHeaderKey} +) + +func (u *UpstreamProxy) handleWebsocket(w http.ResponseWriter, r *http.Request) { + + // Copy request headers and remove websocket handshaking headers + // before submitting to the upstream server + upstreamHeader := http.Header{} + for key, _ := range r.Header { + copyHeader(&upstreamHeader, r.Header, key) + } + for _, header := range HandshakeHeaders { + delete(upstreamHeader, header) + } + upstreamHeader.Set("Host", r.Host) + + // Connect upstream + upstreamAddr := u.upstreamWSURL(*r.URL).String() + upstream, upstreamResp, err := websocket.DefaultDialer.Dial(upstreamAddr, upstreamHeader) + if err != nil { + if upstreamResp != nil { + log.Warn("dialing upstream websocket failed with code %d: %v", upstreamResp.StatusCode, err) + } else { + log.Warn("dialing upstream websocket failed: %v", err) + } + http.Error(w, "websocket unavailable", http.StatusServiceUnavailable) + return + } + defer upstream.Close() + + // Pass websocket handshake response headers to the upgrader + upgradeHeader := http.Header{} + copyHeaders(&upgradeHeader, upstreamResp.Header, UpgradeHeaders) + + // Upgrade the client connection without validating the origin + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + client, err := upgrader.Upgrade(w, r, upgradeHeader) + if err != nil { + log.Printf("couldn't upgrade websocket request: %v", err) + http.Error(w, "websocket upgrade failed", http.StatusServiceUnavailable) + return + } + + // Wire both sides together and close when finished + var wg sync.WaitGroup + cp := func(dst, src *websocket.Conn) { + defer wg.Done() + _, err := io.Copy(dst.UnderlyingConn(), src.UnderlyingConn()) + + var closeMessage []byte + if err != nil { + closeMessage = websocket.FormatCloseMessage(websocket.CloseProtocolError, err.Error()) + } else { + closeMessage = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye") + } + // Attempt to close the connection properly + dst.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(2*time.Second)) + src.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(2*time.Second)) + } + wg.Add(2) + go cp(upstream, client) + go cp(client, upstream) + wg.Wait() +} + +// Create a websocket URL from the request URL +func (u *UpstreamProxy) upstreamWSURL(r url.URL) *url.URL { + ws := r + ws.User = r.User + ws.Host = u.upstream.Host + ws.Fragment = "" + switch u.upstream.Scheme { + case "http": + ws.Scheme = "ws" + case "https": + ws.Scheme = "wss" + } + return &ws +} + +func isWebsocketRequest(req *http.Request) bool { + return isHeaderValuePresent(req.Header, UpgradeHeaderKey, UpgradeHeaderValue) && + isHeaderValuePresent(req.Header, ConnectionHeaderKey, ConnectionHeaderValue) +} + +func isHeaderValuePresent(headers http.Header, key string, value string) bool { + for _, header := range headers[key] { + for _, v := range strings.Split(header, ",") { + if strings.EqualFold(value, strings.TrimSpace(v)) { + return true + } + } + } + return false +} + +func copyHeaders(dst *http.Header, src http.Header, headers []string) { + for _, header := range headers { + copyHeader(dst, src, header) + } +} + +// Copy any non-empty and non-blank header values +func copyHeader(dst *http.Header, src http.Header, header string) { + for _, value := range src[header] { + if value != "" { + dst.Add(header, value) + } + } +} diff --git a/websocket_test.go b/websocket_test.go new file mode 100644 index 0000000..158b32a --- /dev/null +++ b/websocket_test.go @@ -0,0 +1,47 @@ +package main + +import ( + "net/http" + "testing" + + "github.com/bmizerany/assert" +) + +func TestCopyHeader(t *testing.T) { + src := http.Header{ + "EmptyValue": []string{""}, + "Nil": []string{}, + "Single": []string{"one"}, + "Multi": []string{"one", "two"}, + } + expected := http.Header{ + "Single": []string{"one"}, + "Multi": []string{"one", "two"}, + } + dst := http.Header{} + for key, _ := range src { + copyHeader(&dst, src, key) + } + assert.Equal(t, expected, dst) +} + +func TestUpgrade(t *testing.T) { + tests := []struct { + upgrade bool + connectionValue string + upgradeValue string + }{ + {true, "Upgrade", "Websocket"}, + {true, "keepalive, Upgrade", "websocket"}, + {false, "", "websocket"}, + {false, "keepalive, Upgrade", ""}, + } + + for _, tt := range tests { + req := new(http.Request) + req.Header = http.Header{} + req.Header.Set(ConnectionHeaderKey, tt.connectionValue) + req.Header.Set(UpgradeHeaderKey, tt.upgradeValue) + assert.Equal(t, tt.upgrade, isWebsocketRequest(req)) + } +}