Skip to content

Commit

Permalink
Add HeaderFunc to allow modifying headers before every request (#298)
Browse files Browse the repository at this point in the history
Adds a new HeaderFunc to the StartSettings that allows for dynamically editing the headers before each HTTP request made by the OpAMP library.

Closes #297
  • Loading branch information
BinaryFissionGames authored Sep 12, 2024
1 parent 7cdd395 commit b33ab76
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 13 deletions.
73 changes: 73 additions & 0 deletions client/clientimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,79 @@ func TestConnectWithHeader(t *testing.T) {
})
}

func TestConnectWithHeaderFunc(t *testing.T) {
testClients(t, func(t *testing.T, client OpAMPClient) {
// Start a server.
srv := internal.StartMockServer(t)
var conn atomic.Value
srv.OnConnect = func(r *http.Request) {
authHdr := r.Header.Get("Authorization")
assert.EqualValues(t, "Bearer 12345678", authHdr)
userAgentHdr := r.Header.Get("User-Agent")
assert.EqualValues(t, "custom-agent/1.0", userAgentHdr)
conn.Store(true)
}

hf := func(header http.Header) http.Header {
header.Set("Authorization", "Bearer 12345678")
header.Set("User-Agent", "custom-agent/1.0")
return header
}

// Start a client.
settings := types.StartSettings{
OpAMPServerURL: "ws://" + srv.Endpoint,
HeaderFunc: hf,
}
startClient(t, settings, client)

// Wait for connection to be established.
eventually(t, func() bool { return conn.Load() != nil })

// Shutdown the Server and the client.
srv.Close()
_ = client.Stop(context.Background())
})
}

func TestConnectWithHeaderAndHeaderFunc(t *testing.T) {
testClients(t, func(t *testing.T, client OpAMPClient) {
// Start a server.
srv := internal.StartMockServer(t)
var conn atomic.Value
srv.OnConnect = func(r *http.Request) {
authHdr := r.Header.Get("Authorization")
assert.EqualValues(t, "Bearer 12345678", authHdr)
userAgentHdr := r.Header.Get("User-Agent")
assert.EqualValues(t, "custom-agent/1.0", userAgentHdr)
conn.Store(true)
}

baseHeader := http.Header{}
baseHeader.Set("User-Agent", "custom-agent/1.0")

hf := func(header http.Header) http.Header {
header.Set("Authorization", "Bearer 12345678")
return header
}

// Start a client.
settings := types.StartSettings{
OpAMPServerURL: "ws://" + srv.Endpoint,
Header: baseHeader,
HeaderFunc: hf,
}
startClient(t, settings, client)

// Wait for connection to be established.
eventually(t, func() bool { return conn.Load() != nil })

// Shutdown the Server and the client.
srv.Close()
_ = client.Stop(context.Background())
})
}

func TestConnectWithTLS(t *testing.T) {
testClients(t, func(t *testing.T, client OpAMPClient) {
// Start a server.
Expand Down
2 changes: 1 addition & 1 deletion client/httpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (c *httpClient) Start(ctx context.Context, settings types.StartSettings) er
c.opAMPServerURL = settings.OpAMPServerURL

// Prepare Server connection settings.
c.sender.SetRequestHeader(settings.Header)
c.sender.SetRequestHeader(settings.Header, settings.HeaderFunc)

// Add TLS configuration into httpClient
c.sender.AddTLSConfig(settings.TLSConfig)
Expand Down
33 changes: 24 additions & 9 deletions client/internal/httpsender.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ type HTTPSender struct {
compressionEnabled bool

// Headers to send with all requests.
requestHeader http.Header
getHeader func() http.Header

// Processor to handle received messages.
receiveProcessor receivedProcessor
Expand All @@ -75,7 +75,7 @@ func NewHTTPSender(logger types.Logger) *HTTPSender {
pollingIntervalMs: defaultPollingIntervalMs,
}
// initialize the headers with no additional headers
h.SetRequestHeader(nil)
h.SetRequestHeader(nil, nil)
return h
}

Expand Down Expand Up @@ -121,12 +121,26 @@ func (h *HTTPSender) Run(

// SetRequestHeader sets additional HTTP headers to send with all future requests.
// Should not be called concurrently with any other method.
func (h *HTTPSender) SetRequestHeader(header http.Header) {
if header == nil {
header = http.Header{}
func (h *HTTPSender) SetRequestHeader(baseHeaders http.Header, headerFunc func(http.Header) http.Header) {
if baseHeaders == nil {
baseHeaders = http.Header{}
}

if headerFunc == nil {
headerFunc = func(h http.Header) http.Header {
return h
}
}

h.getHeader = func() http.Header {
requestHeader := headerFunc(baseHeaders.Clone())
requestHeader.Set(headerContentType, contentTypeProtobuf)
if h.compressionEnabled {
requestHeader.Set(headerContentEncoding, encodingTypeGZip)
}

return requestHeader
}
h.requestHeader = header
h.requestHeader.Set(headerContentType, contentTypeProtobuf)
}

// makeOneRequestRoundtrip sends a request and receives a response.
Expand Down Expand Up @@ -255,7 +269,7 @@ func (h *HTTPSender) prepareRequest(ctx context.Context) (*requestWrapper, error
return nil, err
}

req.Header = h.requestHeader
req.Header = h.getHeader()
return &req, nil
}

Expand Down Expand Up @@ -295,9 +309,10 @@ func (h *HTTPSender) SetPollingInterval(duration time.Duration) {
atomic.StoreInt64(&h.pollingIntervalMs, duration.Milliseconds())
}

// EnableCompression enables compression for the sender.
// Should not be called concurrently with Run.
func (h *HTTPSender) EnableCompression() {
h.compressionEnabled = true
h.requestHeader.Set(headerContentEncoding, encodingTypeGZip)
}

func (h *HTTPSender) AddTLSConfig(config *tls.Config) {
Expand Down
5 changes: 5 additions & 0 deletions client/types/startsettings.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ type StartSettings struct {
// Optional additional HTTP headers to send with all HTTP requests.
Header http.Header

// Optional function that can be used to modify the HTTP headers
// before each HTTP request.
// Can modify and return the argument or return the argument without modifying.
HeaderFunc func(http.Header) http.Header

// Optional TLS config for HTTP connection.
TLSConfig *tls.Config

Expand Down
20 changes: 17 additions & 3 deletions client/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type wsClient struct {
url *url.URL

// HTTP request headers to use when connecting to OpAMP Server.
requestHeader http.Header
getHeader func() http.Header

// Websocket dialer and connection.
dialer websocket.Dialer
Expand Down Expand Up @@ -86,7 +86,21 @@ func (c *wsClient) Start(ctx context.Context, settings types.StartSettings) erro
}
c.dialer.TLSClientConfig = settings.TLSConfig

c.requestHeader = settings.Header
headerFunc := settings.HeaderFunc
if headerFunc == nil {
headerFunc = func(h http.Header) http.Header {
return h
}
}

baseHeader := settings.Header
if baseHeader == nil {
baseHeader = http.Header{}
}

c.getHeader = func() http.Header {
return headerFunc(baseHeader.Clone())
}

c.common.StartConnectAndRun(c.runUntilStopped)

Expand Down Expand Up @@ -142,7 +156,7 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS
// by the Server.
func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) {
var resp *http.Response
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.requestHeader)
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader())
if err != nil {
if c.common.Callbacks != nil && !c.common.IsStopping() {
c.common.Callbacks.OnConnectFailed(ctx, err)
Expand Down

0 comments on commit b33ab76

Please sign in to comment.