Skip to content

Commit

Permalink
added support for websocket proxying. the support is mainly based on …
Browse files Browse the repository at this point in the history
…PR from soellman on bitly oauth2_proxy

bitly/oauth2_proxy#201
  • Loading branch information
Gurvinder Singh committed Mar 2, 2017
1 parent f183365 commit 09f0599
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 10 deletions.
14 changes: 11 additions & 3 deletions glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion glide.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ import:
- package: golang.org/x/oauth2
- package: github.com/m4rw3r/uuid
- package: github.com/parnurzeal/gorequest
- package: github.com/SermoDigital/jose/jws
- package: github.com/SermoDigital/jose/jws
- package: github.com/gorilla/websocket
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func main() {
}

// Create proxy and middleware
target := NewReverseProxy(targetURL)
upstream := NewUpstreamProxy(targetURL)
authn, err := newAuthenticator(
conf.GetStringValue("engine.client_id"),
conf.GetStringValue("engine.client_secret"),
Expand All @@ -106,7 +106,7 @@ func main() {
// Configure routes
http.Handle("/healthz", healthzHandler(targetURL.String()))
http.Handle("/oauth2/callback", authn.callbackHandler())
http.Handle("/", authn.authHandler(target))
http.Handle("/", authn.authHandler(upstream))

// Start proxying
log.Println("Proxy initialized and listening on port", conf.GetIntValue("server.port"))
Expand Down
27 changes: 23 additions & 4 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ package main
import (
"bytes"
"encoding/json"
log "github.com/Sirupsen/logrus"
"github.com/uninett/goidc-proxy/conf"
"golang.org/x/oauth2"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"

log "github.com/Sirupsen/logrus"
"github.com/uninett/goidc-proxy/conf"
"golang.org/x/oauth2"
)

type transport struct {
Expand All @@ -23,6 +24,11 @@ type ACRValues struct {
Values string `json:"required_acr_values"`
}

type UpstreamProxy struct {
upstream *url.URL
handler http.Handler
}

func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
Expand Down Expand Up @@ -74,8 +80,21 @@ func (t *transport) RoundTrip(req *http.Request) (resp *http.Response, err error
return resp, nil
}

func NewUpstreamProxy(target *url.URL) *UpstreamProxy {
proxy := newReverseProxy(target)
return &UpstreamProxy{target, proxy}
}

func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if isWebsocketRequest(r) {
u.handleWebsocket(w, r)
} else {
u.handler.ServeHTTP(w, r)
}
}

// NewReverseProxy prvoides reverse proxy functionality towards target
func NewReverseProxy(target *url.URL) *httputil.ReverseProxy {
func newReverseProxy(target *url.URL) *httputil.ReverseProxy {
director := func(req *http.Request) {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
Expand Down
139 changes: 139 additions & 0 deletions websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package main

import (
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"

log "github.com/Sirupsen/logrus"
"github.com/gorilla/websocket"
)

var (
ConnectionHeaderKey = http.CanonicalHeaderKey("connection")
SetCookieHeaderKey = http.CanonicalHeaderKey("set-cookie")
UpgradeHeaderKey = http.CanonicalHeaderKey("upgrade")
WSKeyHeaderKey = http.CanonicalHeaderKey("sec-websocket-key")
WSVersionHeaderKey = http.CanonicalHeaderKey("sec-websocket-version")
WSProtocolHeaderKey = http.CanonicalHeaderKey("sec-websocket-protocol")
WSExtensionHeaderKey = http.CanonicalHeaderKey("sec-websocket-extensions")

ConnectionHeaderValue = "Upgrade"
UpgradeHeaderValue = "websocket"

HandshakeHeaders = []string{ConnectionHeaderKey, UpgradeHeaderKey, WSVersionHeaderKey, WSKeyHeaderKey, WSExtensionHeaderKey}
UpgradeHeaders = []string{SetCookieHeaderKey, WSProtocolHeaderKey}
)

func (u *UpstreamProxy) handleWebsocket(w http.ResponseWriter, r *http.Request) {

// Copy request headers and remove websocket handshaking headers
// before submitting to the upstream server
upstreamHeader := http.Header{}
for key, _ := range r.Header {
copyHeader(&upstreamHeader, r.Header, key)
}
for _, header := range HandshakeHeaders {
delete(upstreamHeader, header)
}
upstreamHeader.Set("Host", r.Host)

// Connect upstream
upstreamAddr := u.upstreamWSURL(*r.URL).String()
upstream, upstreamResp, err := websocket.DefaultDialer.Dial(upstreamAddr, upstreamHeader)
if err != nil {
if upstreamResp != nil {
log.Warn("dialing upstream websocket failed with code %d: %v", upstreamResp.StatusCode, err)
} else {
log.Warn("dialing upstream websocket failed: %v", err)
}
http.Error(w, "websocket unavailable", http.StatusServiceUnavailable)
return
}
defer upstream.Close()

// Pass websocket handshake response headers to the upgrader
upgradeHeader := http.Header{}
copyHeaders(&upgradeHeader, upstreamResp.Header, UpgradeHeaders)

// Upgrade the client connection without validating the origin
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
client, err := upgrader.Upgrade(w, r, upgradeHeader)
if err != nil {
log.Printf("couldn't upgrade websocket request: %v", err)
http.Error(w, "websocket upgrade failed", http.StatusServiceUnavailable)
return
}

// Wire both sides together and close when finished
var wg sync.WaitGroup
cp := func(dst, src *websocket.Conn) {
defer wg.Done()
_, err := io.Copy(dst.UnderlyingConn(), src.UnderlyingConn())

var closeMessage []byte
if err != nil {
closeMessage = websocket.FormatCloseMessage(websocket.CloseProtocolError, err.Error())
} else {
closeMessage = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye")
}
// Attempt to close the connection properly
dst.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(2*time.Second))
src.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(2*time.Second))
}
wg.Add(2)
go cp(upstream, client)
go cp(client, upstream)
wg.Wait()
}

// Create a websocket URL from the request URL
func (u *UpstreamProxy) upstreamWSURL(r url.URL) *url.URL {
ws := r
ws.User = r.User
ws.Host = u.upstream.Host
ws.Fragment = ""
switch u.upstream.Scheme {
case "http":
ws.Scheme = "ws"
case "https":
ws.Scheme = "wss"
}
return &ws
}

func isWebsocketRequest(req *http.Request) bool {
return isHeaderValuePresent(req.Header, UpgradeHeaderKey, UpgradeHeaderValue) &&
isHeaderValuePresent(req.Header, ConnectionHeaderKey, ConnectionHeaderValue)
}

func isHeaderValuePresent(headers http.Header, key string, value string) bool {
for _, header := range headers[key] {
for _, v := range strings.Split(header, ",") {
if strings.EqualFold(value, strings.TrimSpace(v)) {
return true
}
}
}
return false
}

func copyHeaders(dst *http.Header, src http.Header, headers []string) {
for _, header := range headers {
copyHeader(dst, src, header)
}
}

// Copy any non-empty and non-blank header values
func copyHeader(dst *http.Header, src http.Header, header string) {
for _, value := range src[header] {
if value != "" {
dst.Add(header, value)
}
}
}
47 changes: 47 additions & 0 deletions websocket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package main

import (
"net/http"
"testing"

"github.com/bmizerany/assert"
)

func TestCopyHeader(t *testing.T) {
src := http.Header{
"EmptyValue": []string{""},
"Nil": []string{},
"Single": []string{"one"},
"Multi": []string{"one", "two"},
}
expected := http.Header{
"Single": []string{"one"},
"Multi": []string{"one", "two"},
}
dst := http.Header{}
for key, _ := range src {
copyHeader(&dst, src, key)
}
assert.Equal(t, expected, dst)
}

func TestUpgrade(t *testing.T) {
tests := []struct {
upgrade bool
connectionValue string
upgradeValue string
}{
{true, "Upgrade", "Websocket"},
{true, "keepalive, Upgrade", "websocket"},
{false, "", "websocket"},
{false, "keepalive, Upgrade", ""},
}

for _, tt := range tests {
req := new(http.Request)
req.Header = http.Header{}
req.Header.Set(ConnectionHeaderKey, tt.connectionValue)
req.Header.Set(UpgradeHeaderKey, tt.upgradeValue)
assert.Equal(t, tt.upgrade, isWebsocketRequest(req))
}
}

0 comments on commit 09f0599

Please sign in to comment.