diff --git a/oidc_gateway.go b/oidc_gateway.go index 7846a5b..14a689b 100644 --- a/oidc_gateway.go +++ b/oidc_gateway.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "math/big" "net" "net/http" @@ -16,18 +15,18 @@ import ( ) type JWK struct { - N string - Kty string - Kid string - Alg string - E string - Use string - X5c []string - X5t string + N string `json:"n"` + Kty string `json:"kty"` + Kid string `json:"kid"` + Alg string `json:"alg"` + E string `json:"e"` + Use string `json:"use"` + X5c []string `json:"x5c"` + X5t string `json:"x5t"` } type JWKS struct { - Keys []JWK + Keys []JWK `json:"keys"` } type GatewayContext struct { @@ -38,25 +37,25 @@ type GatewayContext struct { func getKeyFromJwks(jwksBytes []byte) func(*jwt.Token) (interface{}, error) { return func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } var jwks JWKS if err := json.Unmarshal(jwksBytes, &jwks); err != nil { - return nil, fmt.Errorf("Unable to parse JWKS") + return nil, fmt.Errorf("unable to parse JWKS: %v", err) } for _, jwk := range jwks.Keys { if jwk.Kid == token.Header["kid"] { nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N) if err != nil { - return nil, fmt.Errorf("Unable to parse key") + return nil, fmt.Errorf("unable to parse key N: %v", err) } var n big.Int eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E) if err != nil { - return nil, fmt.Errorf("Unable to parse key") + return nil, fmt.Errorf("unable to parse key E: %v", err) } var e big.Int @@ -69,7 +68,7 @@ func getKeyFromJwks(jwksBytes []byte) func(*jwt.Token) (interface{}, error) { } } - return nil, fmt.Errorf("Unknown kid: %v", token.Header["kid"]) + return nil, fmt.Errorf("unknown kid: %v", token.Header["kid"]) } } @@ -80,14 +79,12 @@ func validateTokenCameFromGitHub(oidcTokenString string, gc *GatewayContext) (jw if now.Sub(gc.jwksLastUpdate) > time.Minute || len(gc.jwksCache) == 0 { resp, err := http.Get("https://token.actions.githubusercontent.com/.well-known/jwks") if err != nil { - fmt.Println(err) - return nil, fmt.Errorf("Unable to get JWKS configuration") + return nil, fmt.Errorf("unable to get JWKS configuration: %v", err) } - jwksBytes, err := ioutil.ReadAll(resp.Body) + jwksBytes, err := io.ReadAll(resp.Body) if err != nil { - fmt.Println(err) - return nil, fmt.Errorf("Unable to get JWKS configuration") + return nil, fmt.Errorf("unable to read JWKS configuration: %v", err) } gc.jwksCache = jwksBytes @@ -95,14 +92,14 @@ func validateTokenCameFromGitHub(oidcTokenString string, gc *GatewayContext) (jw } // Attempt to validate JWT with JWKS - oidcToken, err := jwt.Parse(string(oidcTokenString), getKeyFromJwks(gc.jwksCache)) + oidcToken, err := jwt.Parse(oidcTokenString, getKeyFromJwks(gc.jwksCache)) if err != nil || !oidcToken.Valid { - return nil, fmt.Errorf("Unable to validate JWT") + return nil, fmt.Errorf("unable to validate JWT: %v", err) } claims, ok := oidcToken.Claims.(jwt.MapClaims) if !ok { - return nil, fmt.Errorf("Unable to map JWT claims") + return nil, fmt.Errorf("unable to map JWT claims") } return claims, nil @@ -117,7 +114,6 @@ func transfer(destination io.WriteCloser, source io.ReadCloser) { func handleProxyRequest(w http.ResponseWriter, req *http.Request) { proxyConn, err := net.DialTimeout("tcp", req.Host, 5*time.Second) if err != nil { - fmt.Println(err) http.Error(w, http.StatusText(http.StatusRequestTimeout), http.StatusRequestTimeout) return } @@ -126,14 +122,13 @@ func handleProxyRequest(w http.ResponseWriter, req *http.Request) { hijacker, ok := w.(http.Hijacker) if !ok { - fmt.Println("Connection hijacking not supported") + // Connection hijacking not supported http.Error(w, http.StatusText(http.StatusExpectationFailed), http.StatusExpectationFailed) return } reqConn, _, err := hijacker.Hijack() if err != nil { - fmt.Println(err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } @@ -145,7 +140,6 @@ func handleProxyRequest(w http.ResponseWriter, req *http.Request) { func handleApiRequest(w http.ResponseWriter) { resp, err := http.Get("https://www.bing.com") if err != nil { - fmt.Println(err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } @@ -165,10 +159,13 @@ func (gatewayContext *GatewayContext) ServeHTTP(w http.ResponseWriter, req *http // This only means the OIDC token came from any GitHub Actions workflow, // we *must* check claims specific to our use case below oidcTokenString := string(req.Header.Get("Gateway-Authorization")) + if oidcTokenString == "" { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } claims, err := validateTokenCameFromGitHub(oidcTokenString, gatewayContext) if err != nil { - fmt.Println(err) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } @@ -194,7 +191,6 @@ func (gatewayContext *GatewayContext) ServeHTTP(w http.ResponseWriter, req *http if claims["aud"] != "api://ActionsOIDCGateway" { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return - } // Now that claims have been verified, we can service the request @@ -217,5 +213,7 @@ func main() { WriteTimeout: 60 * time.Second, } - server.ListenAndServe() + if err := server.ListenAndServe(); err != nil { + fmt.Printf("server error: %v\n", err) + } } diff --git a/oidc_gateway_test.go b/oidc_gateway_test.go index 11c527e..5deba30 100644 --- a/oidc_gateway_test.go +++ b/oidc_gateway_test.go @@ -45,7 +45,7 @@ func TestGetKeyForTokenMaker(t *testing.T) { // Test token referencing unknown key token.Header["kid"] = "unknownKey" - key, err = getKeyFunc(token) + _, err = getKeyFunc(token) if err == nil { t.Error("Should fail when passed unknown key") } @@ -53,7 +53,7 @@ func TestGetKeyForTokenMaker(t *testing.T) { // Test token fails with any other signing key than RSA tokenHmac := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenClaims) - key, err = getKeyFunc(tokenHmac) + _, err = getKeyFunc(tokenHmac) if err == nil { t.Error("Should fail any signing algorithm other than RSA") } @@ -107,18 +107,18 @@ func TestValidateTokenCameFromGitHub(t *testing.T) { panic(err) } - claims, err = validateTokenCameFromGitHub(signedToken, gatewayContext) + _, err = validateTokenCameFromGitHub(signedToken, gatewayContext) if err == nil { t.Error("Should not validate token signed with other key") } // Test unsigned token is not allowed - unsigendToken := jwt.NewWithClaims(jwt.SigningMethodNone, tokenClaims) - unsigendToken.Header["kid"] = "testKey" + unsignedToken := jwt.NewWithClaims(jwt.SigningMethodNone, tokenClaims) + unsignedToken.Header["kid"] = "testKey" - noneToken, err := token.SignedString("none signing method allowed") + noneToken, _ := token.SignedString("none signing method allowed") - claims, err = validateTokenCameFromGitHub(noneToken, gatewayContext) + _, err = validateTokenCameFromGitHub(noneToken, gatewayContext) if err == nil { t.Error("Should not validate unsigned token") }