Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CheckRedirect callback #269

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions client/httpclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package client
import (
"compress/gzip"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto"

"github.com/open-telemetry/opamp-go/client/internal"
Expand Down Expand Up @@ -223,3 +227,87 @@ func TestHTTPClientStartWithZeroHeartbeatInterval(t *testing.T) {
// Shutdown the Server.
srv.Close()
}

func mockRedirectHTTP(t testing.TB, viaLen int, err error) *checkRedirectMock {
m := &checkRedirectMock{
t: t,
viaLen: viaLen,
http: true,
}
m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err)
return m
}

func TestRedirectHTTP(t *testing.T) {
redirectee := internal.StartMockServer(t)
tests := []struct {
Name string
Redirector *httptest.Server
ExpError bool
MockRedirect *checkRedirectMock
}{
{
Name: "simple redirect",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
},
{
Name: "check redirect",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
MockRedirect: mockRedirectHTTP(t, 1, nil),
},
{
Name: "check redirect returns error",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
MockRedirect: mockRedirectHTTP(t, 1, errors.New("hello")),
ExpError: true,
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
var connectErr atomic.Value
var connected atomic.Value

settings := &types.StartSettings{
Callbacks: types.Callbacks{
OnConnect: func(ctx context.Context) {
connected.Store(1)
},
OnConnectFailed: func(ctx context.Context, err error) {
connectErr.Store(err)
},
},
}
if test.MockRedirect != nil {
settings.Callbacks = types.Callbacks{
OnConnect: func(ctx context.Context) {
connected.Store(1)
},
OnConnectFailed: func(ctx context.Context, err error) {
connectErr.Store(err)
},
CheckRedirect: test.MockRedirect.CheckRedirect,
}
}
reURL, _ := url.Parse(test.Redirector.URL) // err can't be non-nil
settings.OpAMPServerURL = reURL.String()
client := NewHTTP(nil)
prepareClient(t, settings, client)

err := client.Start(context.Background(), *settings)
if err != nil {
t.Fatal(err)
}
defer client.Stop(context.Background())
// Wait for connection to be established.
eventually(t, func() bool {
return connected.Load() != nil || connectErr.Load() != nil
})
if test.ExpError && connectErr.Load() == nil {
t.Error("expected non-nil error")
} else if err := connectErr.Load(); !test.ExpError && err != nil {
t.Fatal(err)
}
})
}
}
8 changes: 8 additions & 0 deletions client/internal/httpsender.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ func (h *HTTPSender) Run(
h.callbacks = callbacks
h.receiveProcessor = newReceivedProcessor(h.logger, callbacks, h, clientSyncedState, packagesStateProvider, capabilities, packageSyncMutex)

// we need to detect if the redirect was ever set, if not, we want default behaviour
if callbacks.CheckRedirect != nil {
h.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
// viaResp only non-nil for ws client
return callbacks.CheckRedirect(req, via, nil)
}
}

for {
pollingTimer := time.NewTimer(time.Millisecond * time.Duration(atomic.LoadInt64(&h.pollingIntervalMs)))
select {
Expand Down
14 changes: 14 additions & 0 deletions client/types/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package types

import (
"context"
"net/http"

"github.com/open-telemetry/opamp-go/protobufs"
)
Expand Down Expand Up @@ -110,6 +111,19 @@ type Callbacks struct {

// OnCommand is called when the Server requests that the connected Agent perform a command.
OnCommand func(ctx context.Context, command *protobufs.ServerToAgentCommand) error

// CheckRedirect is called before following a redirect, allowing the client
// the opportunity to observe the redirect chain, and optionally terminate
// following redirects early.
//
// CheckRedirect is intended to be similar, although not exactly equivalent,
// to net/http.Client's CheckRedirect feature. Unlike in net/http, the via
// parameter is a slice of HTTP responses, instead of requests. This gives
// an opportunity to users to know what the exact response headers and
// status were. The request itself can be obtained from the response.
//
// The responses in the via parameter are passed with their bodies closed.
CheckRedirect func(req *http.Request, viaReq []*http.Request, via []*http.Response) error
}

func (c *Callbacks) SetDefaults() {
Expand Down
88 changes: 74 additions & 14 deletions client/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ type wsClient struct {
// Network connection timeout used for the WebSocket closing handshake.
// This field is currently only modified during testing.
connShutdownTimeout time.Duration

// responseChain is used for the "via" argument in CheckRedirect.
// It is appended to with every redirect followed, and zeroed on a succesful
// connection. responseChain should only be referred to by the goroutine that
// runs tryConnectOnce and its synchronous callees.
responseChain []*http.Response
}

// NewWebSocket creates a new OpAMP Client that uses WebSocket transport.
Expand Down Expand Up @@ -151,11 +157,77 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS
return c.common.SendCustomMessage(message)
}

func viaReq(resps []*http.Response) []*http.Request {
reqs := make([]*http.Request, 0, len(resps))
for _, resp := range resps {
reqs = append(reqs, resp.Request)
}
return reqs
}

// handleRedirect checks a failed websocket upgrade response for a 3xx response
// and a Location header. If found, it sets the URL to the location found in the
// header so that it is tried on the next retry, instead of the current URL.
func (c *wsClient) handleRedirect(ctx context.Context, resp *http.Response) error {
// append to the responseChain so that subsequent redirects will have access
c.responseChain = append(c.responseChain, resp)

// very liberal handling of 3xx that largely ignores HTTP semantics
redirect, err := resp.Location()
if err != nil {
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
return err
}

// It's slightly tricky to make CheckRedirect work. The WS HTTP request is
// formed within the websocket library. To work around that, copy the
// previous request, available in the response, and set the URL to the new
// location. It should then result in the same URL that the websocket
// library will form.
nextRequest := resp.Request.Clone(ctx)
nextRequest.URL = redirect

// if CheckRedirect results in an error, it gets returned, terminating
// redirection. As with stdlib, the error is wrapped in url.Error.
if c.common.Callbacks.CheckRedirect != nil {
if err := c.common.Callbacks.CheckRedirect(nextRequest, viaReq(c.responseChain), c.responseChain); err != nil {
return &url.Error{
Op: "Get",
URL: nextRequest.URL.String(),
Err: err,
}
}
}

// rewrite the scheme for the sake of tolerance
if redirect.Scheme == "http" {
redirect.Scheme = "ws"
} else if redirect.Scheme == "https" {
redirect.Scheme = "wss"
}
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)

// Set the URL to the redirect, so that it connects to it on the
// next cycle.
c.url = redirect

return nil
}

// Try to connect once. Returns an error if connection fails and optional retryAfter
// duration to indicate to the caller to retry after the specified time as instructed
// by the Server.
func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) {
var resp *http.Response
var redirecting bool
defer func() {
if err != nil && !redirecting {
c.responseChain = nil
if !c.common.IsStopping() {
c.common.Callbacks.OnConnectFailed(ctx, err)
}
}
}()
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader())
if err != nil {
if !c.common.IsStopping() {
Expand All @@ -164,22 +236,10 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinterna
if resp != nil {
duration := sharedinternal.ExtractRetryAfterHeader(resp)
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
// very liberal handling of 3xx that largely ignores HTTP semantics
redirect, err := resp.Location()
if err != nil {
c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err)
redirecting = true
if err := c.handleRedirect(ctx, resp); err != nil {
return duration, err
}
// rewrite the scheme for the sake of tolerance
if redirect.Scheme == "http" {
redirect.Scheme = "ws"
} else if redirect.Scheme == "https" {
redirect.Scheme = "wss"
}
c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect)
// Set the URL to the redirect, so that it connects to it on the
// next cycle.
c.url = redirect
} else {
c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status)
}
Expand Down
Loading
Loading