Skip to content

Commit

Permalink
Merge branch 'main' into benkraft.generate-test
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminjkraft authored Nov 30, 2024
2 parents 3b1453c + 800909d commit a8153a6
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 9 deletions.
20 changes: 18 additions & 2 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,41 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client {
return newClient(endpoint, httpClient, http.MethodGet)
}

type WebSocketOption func(*webSocketClient)

// NewClientUsingWebSocket returns a [WebSocketClient] which makes subscription requests
// to the given endpoint using webSocket.
//
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient {
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header, opts ...WebSocketOption) WebSocketClient {
if headers == nil {
headers = http.Header{}
}
if headers.Get("Sec-WebSocket-Protocol") == "" {
headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
}
return &webSocketClient{
client := &webSocketClient{
Dialer: wsDialer,
Header: headers,
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
}

for _, opt := range opts {
opt(client)
}

return client
}

// WithConnectionParams sets up connection params to be sent to the server
// during the initial connection handshake.
func WithConnectionParams(connParams map[string]interface{}) WebSocketOption {
return func(ws *webSocketClient) {
ws.connParams = connParams
}
}

func newClient(endpoint string, httpClient Doer, method string) Client {
Expand Down
11 changes: 9 additions & 2 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,18 @@ type webSocketClient struct {
Header http.Header
endpoint string
conn WSConn
connParams map[string]interface{}
errChan chan error
subscriptions subscriptionMap
isClosing bool
sync.Mutex
}

type webSocketInitMessage struct {
Payload map[string]interface{} `json:"payload"`
Type string `json:"type"`
}

type webSocketSendMessage struct {
Payload *Request `json:"payload"`
Type string `json:"type"`
Expand All @@ -67,8 +73,9 @@ type webSocketReceiveMessage struct {
}

func (w *webSocketClient) sendInit() error {
connInitMsg := webSocketSendMessage{
Type: webSocketTypeConnInit,
connInitMsg := webSocketInitMessage{
Type: webSocketTypeConnInit,
Payload: w.connParams,
}
return w.sendStructAsJSON(connInitMsg)
}
Expand Down
56 changes: 56 additions & 0 deletions internal/integration/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 88 additions & 0 deletions internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ func TestSubscription(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(t, server.URL)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)

Expand Down Expand Up @@ -146,6 +147,93 @@ func TestSubscription(t *testing.T) {
}
}

func TestSubscriptionConnectionParams(t *testing.T) {
_ = `# @genqlient
subscription countAuthorized { countAuthorized }`

authKey := server.AuthKey

ctx := context.Background()
server := server.RunServer()
defer server.Close()

cases := []struct {
connParams map[string]interface{}
name string
expectedError string
opts []graphql.WebSocketOption
}{
{
name: "authorized_user_gets_counter",
opts: []graphql.WebSocketOption{
graphql.WithConnectionParams(map[string]interface{}{
authKey: "authorized-user-token",
}),
},
},
{
name: "unauthorized_user_gets_error",
expectedError: "input: countAuthorized unauthorized\n",
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(
t,
server.URL,
tc.opts...,
)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)

dataChan, subscriptionID, err := countAuthorized(ctx, wsClient)
require.NoError(t, err)
defer wsClient.Close()

var (
counter = 0
start = time.Now()
)

for loop := true; loop; {
select {
case resp, more := <-dataChan:
if !more {
loop = false
break
}

if tc.expectedError != "" {
require.Error(t, resp.Errors)
assert.Equal(t, tc.expectedError, resp.Errors.Error())
continue
}

require.NotNil(t, resp.Data)
assert.Equal(t, counter, resp.Data.CountAuthorized)
require.Nil(t, resp.Errors)

if time.Since(start) > 5*time.Second {
err := wsClient.Unsubscribe(subscriptionID)
require.NoError(t, err)
loop = false
}

counter++

case err := <-errChan:
require.NoError(t, err)

case <-time.After(10 * time.Second):
require.NoError(t, fmt.Errorf("subscription timed out"))
}
}
})
}
}

func TestServerError(t *testing.T) {
_ = `# @genqlient
query failingQuery { fail me { id } }`
Expand Down
12 changes: 9 additions & 3 deletions internal/integration/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,20 @@ func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeade
return graphql.WSConn(conn), err
}

func newRoundtripWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient {
func newRoundtripWebSocketClient(t *testing.T, endpoint string, opts ...graphql.WebSocketOption) graphql.WebSocketClient {
dialer := websocket.DefaultDialer
if !strings.HasPrefix(endpoint, "ws") {
_, address, _ := strings.Cut(endpoint, "://")
endpoint = "ws://" + address
}

return &roundtripClient{
wsWrapped: graphql.NewClientUsingWebSocket(endpoint, &MyDialer{Dialer: dialer}, nil),
t: t,
wsWrapped: graphql.NewClientUsingWebSocket(
endpoint,
&MyDialer{Dialer: dialer},
nil,
opts...,
),
t: t,
}
}
1 change: 1 addition & 0 deletions internal/integration/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Mutation {

type Subscription {
count: Int!
countAuthorized: Int!
}

type User implements Being & Lucky {
Expand Down
72 changes: 71 additions & 1 deletion internal/integration/server/gqlgen_exec.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit a8153a6

Please sign in to comment.