From d3e516b53f5fa691591ba3062fd52f58b3dfe6a9 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Mon, 2 Dec 2024 06:21:21 +0100 Subject: [PATCH] Move websocket headers to opt function 'WithWebsocketHeaders' (#365) Follow-up up on the discussion in https://github.com/Khan/genqlient/pull/360#pullrequestreview-2471038655. Move websocket headers to and opt function 'WithWebsocketHeaders'. Note this is a breaking change for users using the main branch (but not for users on tagged releases). I have: - [x] Written a clear PR title and description (above) - [x] Signed the [Khan Academy CLA](https://www.khanacademy.org/r/cla) - [x] Added tests covering my changes, if applicable - [x] Included a link to the issue fixed, if applicable - [x] Included documentation, for new features - [x] Added an entry to the changelog --- graphql/client.go | 21 +++++++++++++-------- graphql/websocket.go | 4 ++-- internal/integration/integration_test.go | 10 +++++++++- internal/integration/roundtrip.go | 1 - internal/integration/server/server.go | 20 +++++++++++++++++++- 5 files changed, 43 insertions(+), 13 deletions(-) diff --git a/graphql/client.go b/graphql/client.go index 334f1e6..8278759 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -133,16 +133,10 @@ type WebSocketOption func(*webSocketClient) // // The client does not support queries nor mutations, and will return an error // if passed a request that attempts one. -func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header, opts ...WebSocketOption) WebSocketClient { - if headers == nil { - headers = http.Header{} - } - if headers.Get("Sec-WebSocket-Protocol") == "" { - headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws") - } +func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, opts ...WebSocketOption) WebSocketClient { client := &webSocketClient{ Dialer: wsDialer, - Header: headers, + header: http.Header{}, errChan: make(chan error), endpoint: endpoint, subscriptions: subscriptionMap{map_: make(map[string]subscription)}, @@ -152,6 +146,10 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head opt(client) } + if client.header.Get("Sec-WebSocket-Protocol") == "" { + client.header.Add("Sec-WebSocket-Protocol", "graphql-transport-ws") + } + return client } @@ -163,6 +161,13 @@ func WithConnectionParams(connParams map[string]interface{}) WebSocketOption { } } +// WithWebsocketHeader sets a header to be sent to the server. +func WithWebsocketHeader(header http.Header) WebSocketOption { + return func(ws *webSocketClient) { + ws.header = header + } +} + func newClient(endpoint string, httpClient Doer, method string) Client { if httpClient == nil || httpClient == (*http.Client)(nil) { httpClient = http.DefaultClient diff --git a/graphql/websocket.go b/graphql/websocket.go index 5fb97ce..0e381ec 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -45,7 +45,7 @@ const ( type webSocketClient struct { Dialer Dialer - Header http.Header + header http.Header endpoint string conn WSConn connParams map[string]interface{} @@ -169,7 +169,7 @@ func checkConnectionAckReceived(message []byte) (bool, error) { } func (w *webSocketClient) Start(ctx context.Context) (errChan chan error, err error) { - w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.Header) + w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.header) if err != nil { return nil, err } diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 9aa2166..042c20f 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -164,13 +164,21 @@ func TestSubscriptionConnectionParams(t *testing.T) { opts []graphql.WebSocketOption }{ { - name: "authorized_user_gets_counter", + name: "connection_params_authorized_user_gets_counter", opts: []graphql.WebSocketOption{ graphql.WithConnectionParams(map[string]interface{}{ authKey: "authorized-user-token", }), }, }, + { + name: "http_header_authorized_user_gets_counter", + opts: []graphql.WebSocketOption{ + graphql.WithWebsocketHeader(http.Header{ + authKey: []string{"authorized-user-token"}, + }), + }, + }, { name: "unauthorized_user_gets_error", expectedError: "input: countAuthorized unauthorized\n", diff --git a/internal/integration/roundtrip.go b/internal/integration/roundtrip.go index 835fa69..f7a4246 100644 --- a/internal/integration/roundtrip.go +++ b/internal/integration/roundtrip.go @@ -167,7 +167,6 @@ func newRoundtripWebSocketClient(t *testing.T, endpoint string, opts ...graphql. wsWrapped: graphql.NewClientUsingWebSocket( endpoint, &MyDialer{Dialer: dialer}, - nil, opts..., ), t: t, diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index ff996a3..71e08ad 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "net/http" "net/http/httptest" "strconv" "time" @@ -198,6 +199,20 @@ func getAuthToken(ctx context.Context) string { return "" } +func authHeaderMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + token := r.Header.Get(AuthKey) + if token != "" { + ctx = withAuthToken(ctx, token) + } + + r = r.WithContext(ctx) + handler.ServeHTTP(w, r) + }) +} + func RunServer() *httptest.Server { gqlgenServer := handler.New(NewExecutableSchema(Config{Resolvers: &resolver{}})) gqlgenServer.AddTransport(transport.POST{}) @@ -216,7 +231,10 @@ func RunServer() *httptest.Server { graphql.RegisterExtension(ctx, "foobar", "test") return next(ctx) }) - return httptest.NewServer(gqlgenServer) + + server := authHeaderMiddleware(gqlgenServer) + + return httptest.NewServer(server) } type (