Skip to content

Commit

Permalink
Add identity implementation (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman authored Jul 15, 2024
1 parent fe436d0 commit 5924d37
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 8 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ module github.com/restatedev/sdk-go
go 1.22.0

require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/mr-tron/base58 v1.2.0
github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0
github.com/stretchr/testify v1.9.0
github.com/vmihailenco/msgpack/v5 v5.4.1
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 h1:zZg03nifrj6ayWNaDO8tNj57tqrOIKDwiBaLkhtK7Kk=
Expand Down
30 changes: 30 additions & 0 deletions internal/identity/identity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package identity

import "fmt"

const SIGNATURE_SCHEME_HEADER = "X-Restate-Signature-Scheme"

type SignatureScheme string

var (
SchemeUnsigned SignatureScheme = "unsigned"
errMissingIdentity = fmt.Errorf("request has no identity")
)

func ValidateRequestIdentity(keySet KeySetV1, path string, headers map[string][]string) error {
switch len(headers[SIGNATURE_SCHEME_HEADER]) {
case 0:
return errMissingIdentity
case 1:
switch SignatureScheme(headers[SIGNATURE_SCHEME_HEADER][0]) {
case SchemeV1:
return validateV1(keySet, path, headers)
case SchemeUnsigned:
return errMissingIdentity
default:
return fmt.Errorf("unexpected signature scheme %v, allowed values are [%s %s]", headers[SIGNATURE_SCHEME_HEADER][0], SchemeUnsigned, SchemeV1)
}
default:
return fmt.Errorf("unexpected multi-value signature scheme header: %v", headers[SIGNATURE_SCHEME_HEADER])
}
}
83 changes: 83 additions & 0 deletions internal/identity/v1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package identity

import (
"crypto/ed25519"
"fmt"
"strings"

jwt "github.com/golang-jwt/jwt/v5"
"github.com/mr-tron/base58"
)

const (
JWT_HEADER = "X-Restate-Jwt-V1"
SchemeV1 SignatureScheme = "v1"
)

type KeySetV1 = map[string]ed25519.PublicKey

func validateV1(keySet KeySetV1, path string, headers map[string][]string) error {
switch len(headers[JWT_HEADER]) {
case 0:
return fmt.Errorf("v1 signature scheme expects the following headers: [%s]", JWT_HEADER)
case 1:
default:
return fmt.Errorf("unexpected multi-value JWT header: %v", headers[JWT_HEADER])
}

token, err := jwt.Parse(headers[JWT_HEADER][0], func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}

kid, ok := token.Header["kid"]
if !ok {
return nil, fmt.Errorf("Token missing 'kid' header field")
}

kidS, ok := kid.(string)
if !ok {
return nil, fmt.Errorf("Token 'kid' header field was not a string: %v", kid)
}

key, ok := keySet[kidS]
if !ok {
return nil, fmt.Errorf("Key ID %s is not present in key set", kid)
}

return key, nil
}, jwt.WithValidMethods([]string{"EdDSA"}), jwt.WithAudience(path), jwt.WithExpirationRequired())
if err != nil {
return fmt.Errorf("failed to validate v1 request identity jwt: %w", err)
}

nbf, _ := token.Claims.GetNotBefore()
if nbf == nil {
// jwt library only validates nbf if its present, so we should check it was present
return fmt.Errorf("'nbf' claim is missing in v1 request identity jwt")
}

return nil
}

func ParseKeySetV1(keys []string) (KeySetV1, error) {
out := make(KeySetV1, len(keys))
for _, key := range keys {
if !strings.HasPrefix(key, "publickeyv1_") {
return nil, fmt.Errorf("v1 public key must start with 'publickeyv1_'")
}

pubBytes, err := base58.Decode(key[len("publickeyv1_"):])
if err != nil {
return nil, fmt.Errorf("v1 public key must be valid base58: %w", err)
}

if len(pubBytes) != ed25519.PublicKeySize {
return nil, fmt.Errorf("v1 public key must have exactly %d bytes, found %d", ed25519.PublicKeySize, len(pubBytes))
}

out[key] = ed25519.PublicKey(pubBytes)
}

return out, nil
}
2 changes: 1 addition & 1 deletion internal/log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (t stringerValue[T]) LogValue() slog.Value {
}

func Stringer[T fmt.Stringer](key string, value T) slog.Attr {
return slog.Any(key, slog.AnyValue(stringerValue[T]{value}))
return slog.Any(key, stringerValue[T]{value})
}

func Error(err error) slog.Attr {
Expand Down
42 changes: 35 additions & 7 deletions server/restate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/restatedev/sdk-go/generated/proto/discovery"
"github.com/restatedev/sdk-go/generated/proto/protocol"
"github.com/restatedev/sdk-go/internal"
"github.com/restatedev/sdk-go/internal/identity"
"github.com/restatedev/sdk-go/internal/log"
"github.com/restatedev/sdk-go/internal/state"
"golang.org/x/net/http2"
Expand Down Expand Up @@ -45,6 +46,8 @@ type Restate struct {
dropReplayLogs bool
systemLog *slog.Logger
routers map[string]restate.Router
keyIDs []string
keySet identity.KeySetV1
}

// NewRestate creates a new instance of Restate server
Expand All @@ -69,6 +72,11 @@ func (r *Restate) WithLogger(h slog.Handler, dropReplayLogs bool) *Restate {
return r
}

func (r *Restate) WithIdentityV1(keys ...string) *Restate {
r.keyIDs = append(r.keyIDs, keys...)
return r
}

func (r *Restate) Bind(router restate.Router) *Restate {
if _, ok := r.routers[router.Name()]; ok {
// panic because this is a programming error
Expand Down Expand Up @@ -120,38 +128,37 @@ func (r *Restate) discoverHandler(writer http.ResponseWriter, req *http.Request)

acceptVersionsString := req.Header.Get("accept")
if acceptVersionsString == "" {
writer.Write([]byte("missing accept header"))
writer.WriteHeader(http.StatusUnsupportedMediaType)
writer.Write([]byte("missing accept header"))

return
}

serviceDiscoveryProtocolVersion := selectSupportedServiceDiscoveryProtocolVersion(acceptVersionsString)

if serviceDiscoveryProtocolVersion == discovery.ServiceDiscoveryProtocolVersion_SERVICE_DISCOVERY_PROTOCOL_VERSION_UNSPECIFIED {
writer.Write([]byte(fmt.Sprintf("Unsupported service discovery protocol version '%s'", acceptVersionsString)))
writer.WriteHeader(http.StatusUnsupportedMediaType)
writer.Write([]byte(fmt.Sprintf("Unsupported service discovery protocol version '%s'", acceptVersionsString)))
return
}

response, err := r.discover()
if err != nil {
writer.Write([]byte(err.Error()))
writer.WriteHeader(http.StatusInternalServerError)
writer.Write([]byte(err.Error()))

return
}

bytes, err := json.Marshal(response)
if err != nil {
writer.Write([]byte(err.Error()))
writer.WriteHeader(http.StatusInternalServerError)
writer.Write([]byte(err.Error()))

return
}

writer.Header().Add("Content-Type", serviceDiscoveryProtocolVersionToHeaderValue(serviceDiscoveryProtocolVersion))
writer.WriteHeader(200)
if _, err := writer.Write(bytes); err != nil {
r.systemLog.LogAttrs(req.Context(), slog.LevelError, "Failed to write discovery information", log.Error(err))
}
Expand Down Expand Up @@ -252,6 +259,17 @@ func (r *Restate) callHandler(serviceProtocolVersion protocol.ServiceProtocolVer
}

func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) {
if r.keySet != nil {
if err := identity.ValidateRequestIdentity(r.keySet, request.RequestURI, request.Header); err != nil {
r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Rejecting request as its JWT did not validate", log.Error(err))

writer.WriteHeader(http.StatusUnauthorized)
writer.Write([]byte("Unauthorized"))

return
}
}

if request.RequestURI == "/discover" {
r.discoverHandler(writer, request)
return
Expand All @@ -261,8 +279,8 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) {
if serviceProtocolVersionString == "" {
r.systemLog.ErrorContext(request.Context(), "Missing content-type header")

writer.Write([]byte("missing content-type header"))
writer.WriteHeader(http.StatusUnsupportedMediaType)
writer.Write([]byte("missing content-type header"))

return
}
Expand All @@ -272,8 +290,8 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) {
if !isServiceProtocolVersionSupported(serviceProtocolVersion) {
r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Unsupported service protocol version", slog.String("version", serviceProtocolVersionString))

writer.Write([]byte(fmt.Sprintf("Unsupported service protocol version '%s'", serviceProtocolVersionString)))
writer.WriteHeader(http.StatusUnsupportedMediaType)
writer.Write([]byte(fmt.Sprintf("Unsupported service protocol version '%s'", serviceProtocolVersionString)))

return
}
Expand All @@ -297,6 +315,16 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) {
}

func (r *Restate) Start(ctx context.Context, address string) error {
if r.keyIDs == nil {
r.systemLog.WarnContext(ctx, "Accepting requests without validating request signatures; handler access must be restricted")
} else {
ks, err := identity.ParseKeySetV1(r.keyIDs)
if err != nil {
return fmt.Errorf("invalid request identity keys: %w", err)
}
r.keySet = ks
r.systemLog.LogAttrs(ctx, slog.LevelInfo, "Validating requests using signing keys", slog.Any("keys", r.keyIDs))
}

listener, err := net.Listen("tcp", address)
if err != nil {
Expand Down

0 comments on commit 5924d37

Please sign in to comment.