Skip to content

Commit

Permalink
Implement CheckRedirect for HTTP
Browse files Browse the repository at this point in the history
This commit adds support for a CheckRedirect callback to the HTTP opamp
client. It also unifies the API for CheckRedirect between WS and HTTP,
so that the same callback can be used in either circumstance.

Signed-off-by: Eric Chlebek <[email protected]>
  • Loading branch information
echlebek committed Dec 19, 2024
1 parent da944f3 commit b711c39
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 22 deletions.
88 changes: 88 additions & 0 deletions client/httpclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package client
import (
"compress/gzip"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto"

"github.com/open-telemetry/opamp-go/client/internal"
Expand Down Expand Up @@ -223,3 +227,87 @@ func TestHTTPClientStartWithZeroHeartbeatInterval(t *testing.T) {
// Shutdown the Server.
srv.Close()
}

func mockRedirectHTTP(t testing.TB, viaLen int, err error) *checkRedirectMock {
m := &checkRedirectMock{
t: t,
viaLen: viaLen,
http: true,
}
m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err)
return m
}

func TestRedirectHTTP(t *testing.T) {
redirectee := internal.StartMockServer(t)
tests := []struct {
Name string
Redirector *httptest.Server
ExpError bool
MockRedirect *checkRedirectMock
}{
{
Name: "simple redirect",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
},
{
Name: "check redirect",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
MockRedirect: mockRedirectHTTP(t, 1, nil),
},
{
Name: "check redirect returns error",
Redirector: redirectServer("http://"+redirectee.Endpoint, 302),
MockRedirect: mockRedirect(t, 1, errors.New("hello")),
ExpError: true,
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
var connectErr atomic.Value
var connected atomic.Value

settings := &types.StartSettings{
Callbacks: types.Callbacks{
OnConnect: func(ctx context.Context) {
connected.Store(1)
},
OnConnectFailed: func(ctx context.Context, err error) {
connectErr.Store(err)
},
},
}
if test.MockRedirect != nil {
settings.Callbacks = types.Callbacks{
OnConnect: func(ctx context.Context) {
connected.Store(1)
},
OnConnectFailed: func(ctx context.Context, err error) {
connectErr.Store(err)
},
CheckRedirect: test.MockRedirect.CheckRedirect,
}
}
reURL, _ := url.Parse(test.Redirector.URL) // err can't be non-nil
settings.OpAMPServerURL = reURL.String()
client := NewHTTP(nil)
prepareClient(t, settings, client)

err := client.Start(context.Background(), *settings)
if err != nil {
t.Fatal(err)
}
defer client.Stop(context.Background())
// Wait for connection to be established.
eventually(t, func() bool {
return connected.Load() != nil || connectErr.Load() != nil
})
if test.ExpError && connectErr.Load() == nil {
t.Error("expected non-nil error")
} else if err := connectErr.Load(); !test.ExpError && err != nil {
t.Fatal(err)
}
})
}
}
8 changes: 8 additions & 0 deletions client/internal/httpsender.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ func (h *HTTPSender) Run(
h.callbacks = callbacks
h.receiveProcessor = newReceivedProcessor(h.logger, callbacks, h, clientSyncedState, packagesStateProvider, capabilities, packageSyncMutex)

// we need to detect if the redirect was ever set, if not, we want default behaviour
if callbacks.CheckRedirect != nil {
h.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
// viaResp only non-nil for ws client
return callbacks.CheckRedirect(req, via, nil)
}
}

for {
pollingTimer := time.NewTimer(time.Millisecond * time.Duration(atomic.LoadInt64(&h.pollingIntervalMs)))
select {
Expand Down
16 changes: 12 additions & 4 deletions client/types/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,18 @@ type Callbacks struct {
// OnCommand is called when the Server requests that the connected Agent perform a command.
OnCommand func(ctx context.Context, command *protobufs.ServerToAgentCommand) error

// CheckRedirect is called before following a redirect. It is similar in
// nature to the CheckRedirect in net/http's Client. If the value is nil,
// then the http client's CheckRedirect will not be altered.
CheckRedirect func(req *http.Request, via []*http.Response) error
// CheckRedirect is called before following a redirect, allowing the client
// the opportunity to observe the redirect chain, and optionally terminate
// following redirects early.
//
// CheckRedirect is intended to be similar, although not exactly equivalent,
// to net/http.Client's CheckRedirect feature. Unlike in net/http, the via
// parameter is a slice of HTTP responses, instead of requests. This gives
// an opportunity to users to know what the exact response headers and
// status were. The request itself can be obtained from the response.
//
// The responses in the via parameter are passed with their bodies closed.
CheckRedirect func(req *http.Request, viaReq []*http.Request, via []*http.Response) error
}

func (c *Callbacks) SetDefaults() {
Expand Down
16 changes: 14 additions & 2 deletions client/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS
return c.common.SendCustomMessage(message)
}

func viaReq(resps []*http.Response) []*http.Request {
reqs := make([]*http.Request, 0, len(resps))
for _, resp := range resps {
reqs = append(reqs, resp.Request)
}
return reqs

Check warning on line 165 in client/wsclient.go

View check run for this annotation

Codecov / codecov/patch

client/wsclient.go#L160-L165

Added lines #L160 - L165 were not covered by tests
}

// handleRedirect checks a failed websocket upgrade response for a 3xx response
// and a Location header. If found, it sets the URL to the location found in the
// header so that it is tried on the next retry, instead of the current URL.
Expand All @@ -182,7 +190,11 @@ func (c *wsClient) handleRedirect(ctx context.Context, resp *http.Response) erro
// if CheckRedirect results in an error, it gets returned, terminating
// redirection. As with stdlib, the error is wrapped in url.Error.
if c.common.Callbacks.CheckRedirect != nil {
if err := c.common.Callbacks.CheckRedirect(nextRequest, c.responseChain); err != nil {
reqChain := make([]*http.Request, 0, len(c.responseChain))
for _, resp := range c.responseChain {
reqChain = append(reqChain, resp.Request)
}
if err := c.common.Callbacks.CheckRedirect(nextRequest, reqChain, c.responseChain); err != nil {
return &url.Error{
Op: "Get",
URL: nextRequest.URL.String(),
Expand Down Expand Up @@ -215,7 +227,7 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinterna
defer func() {
if err != nil && !redirecting {
c.responseChain = nil
if c.common.Callbacks != nil && !c.common.IsStopping() {
if !c.common.IsStopping() {
c.common.Callbacks.OnConnectFailed(ctx, err)
}
}
Expand Down
39 changes: 23 additions & 16 deletions client/wsclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,22 +328,31 @@ type checkRedirectMock struct {
mock.Mock
t testing.TB
viaLen int
http bool
}

func (c *checkRedirectMock) CheckRedirect(req *http.Request, via []*http.Response) error {
func (c *checkRedirectMock) CheckRedirect(req *http.Request, viaReq []*http.Request, via []*http.Response) error {
if req == nil {
c.t.Error("nil request in CheckRedirect")
return errors.New("nil request in CheckRedirect")
}
if len(via) > c.viaLen {
c.t.Error("via should be shorter than viaLen")
if len(viaReq) > c.viaLen {
c.t.Error("viaReq should be shorter than viaLen")
}
location, err := via[len(via)-1].Location()
if err != nil {
c.t.Error(err)
if !c.http {
// websocket transport
if len(via) > c.viaLen {
c.t.Error("via should be shorter than viaLen")
}
}
if !c.http && len(via) > 0 {
location, err := via[len(via)-1].Location()
if err != nil {
c.t.Error(err)
}
// the URL of the request should match the location header of the last response
assert.Equal(c.t, req.URL, location, "request URL should equal the location in the response")
}
// the URL of the request should match the location header of the last response
assert.Equal(c.t, req.URL, location, "request URL should equal the location in the response")
return c.Called(req, via).Error(0)
}

Expand All @@ -352,7 +361,7 @@ func mockRedirect(t testing.TB, viaLen int, err error) *checkRedirectMock {
t: t,
viaLen: viaLen,
}
m.On("CheckRedirect", mock.Anything, mock.Anything).Return(err)
m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err)
return m
}

Expand Down Expand Up @@ -403,8 +412,6 @@ func TestRedirectWS(t *testing.T) {
settings := types.StartSettings{
Callbacks: types.Callbacks{
OnConnect: func(ctx context.Context) {
Callbacks: &types.Callbacks{
OnConnectFunc: func(ctx context.Context) {
atomic.StoreInt64(&connected, 1)
},
OnConnectFailed: func(ctx context.Context, err error) {
Expand All @@ -415,7 +422,7 @@ func TestRedirectWS(t *testing.T) {
},
}
if test.MockRedirect != nil {
settings.Callbacks.(*types.CallbacksStruct).CheckRedirectFunc = test.MockRedirect.CheckRedirect
settings.Callbacks.CheckRedirect = test.MockRedirect.CheckRedirect
}
reURL, err := url.Parse(test.Redirector.URL)
assert.NoError(t, err)
Expand Down Expand Up @@ -468,16 +475,16 @@ func TestRedirectWSFollowChain(t *testing.T) {
var connectErr atomic.Value
mr := mockRedirect(t, 2, nil)
settings := types.StartSettings{
Callbacks: types.CallbacksStruct{
OnConnectFunc: func(ctx context.Context) {
Callbacks: types.Callbacks{
OnConnect: func(ctx context.Context) {
atomic.StoreInt64(&connected, 1)
},
OnConnectFailedFunc: func(ctx context.Context, err error) {
OnConnectFailed: func(ctx context.Context, err error) {
if err != websocket.ErrBadHandshake {
connectErr.Store(err)
}
},
CheckRedirectFunc: mr.CheckRedirect,
CheckRedirect: mr.CheckRedirect,
},
}
reURL, err := url.Parse(redirector.URL)
Expand Down

0 comments on commit b711c39

Please sign in to comment.