Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support function authentication with OpenFaaS IAM #996

Merged
merged 6 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 55 additions & 8 deletions commands/general.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package commands

import (
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"os"
"time"

"github.com/openfaas/faas-cli/proxy"
"github.com/openfaas/faas-cli/config"
"github.com/openfaas/go-sdk"
)

Expand Down Expand Up @@ -49,21 +50,67 @@ func GetDefaultSDKClient() (*sdk.Client, error) {
return nil, err
}

transport := GetDefaultCLITransport(tlsInsecure, &commandTimeout)
authConfig, err := config.LookupAuthConfig(gatewayURL.String())
if err != nil {
fmt.Printf("Failed to lookup auth config: %s\n", err)
}

var clientAuth sdk.ClientAuth
var functionTokenSource sdk.TokenSource
if authConfig.Auth == config.BasicAuthType {
username, password, err := config.DecodeAuth(authConfig.Token)
if err != nil {
return nil, err
}

clientAuth = &sdk.BasicAuth{
Username: username,
Password: password,
}
}

if authConfig.Auth == config.Oauth2AuthType {
tokenAuth := &StaticTokenAuth{
token: authConfig.Token,
}

clientAuth = tokenAuth
functionTokenSource = tokenAuth
}

// User specified token gets priority
if len(token) > 0 {
tokenAuth := &StaticTokenAuth{
token: token,
}

clientAuth = tokenAuth
functionTokenSource = tokenAuth
}

httpClient := &http.Client{}
httpClient.Timeout = commandTimeout

transport := GetDefaultCLITransport(tlsInsecure, &commandTimeout)
if transport != nil {
httpClient.Transport = transport
}

clientAuth, err := proxy.NewCLIAuth(token, gatewayAddress)
if err != nil {
return nil, err
}
return sdk.NewClientWithOpts(gatewayURL, httpClient,
sdk.WithAuthentication(clientAuth),
sdk.WithFunctionTokenSource(functionTokenSource),
), nil
}

type StaticTokenAuth struct {
token string
}

client := sdk.NewClient(gatewayURL, clientAuth, http.DefaultClient)
func (a *StaticTokenAuth) Set(req *http.Request) error {
req.Header.Add("Authorization", "Bearer "+a.token)
return nil
}

return client, nil
func (ts *StaticTokenAuth) Token() (string, error) {
return ts.token, nil
}
202 changes: 171 additions & 31 deletions commands/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
package commands

import (
"bytes"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"os"
"runtime"
"strings"

"github.com/alexellis/hmac"
"github.com/openfaas/faas-cli/proxy"
"github.com/openfaas/faas-cli/stack"
"github.com/openfaas/faas-cli/version"
"github.com/spf13/cobra"
)

Expand All @@ -24,8 +28,11 @@ var (
sigHeader string
key string
functionInvokeNamespace string
authenticate bool
)

const functionInvokeRealm = "IAM function invoke"

func init() {
// Setup flags that are used by multiple commands (variables defined in faas.go)
invokeCmd.Flags().StringVar(&functionName, "name", "", "Name of the deployed function")
Expand All @@ -36,6 +43,7 @@ func init() {
invokeCmd.Flags().StringVar(&contentType, "content-type", "text/plain", "The content-type HTTP header such as application/json")
invokeCmd.Flags().StringArrayVar(&query, "query", []string{}, "pass query-string options")
invokeCmd.Flags().StringArrayVarP(&headers, "header", "H", []string{}, "pass HTTP request header")
invokeCmd.Flags().BoolVar(&authenticate, "auth", false, "Authenticate with an OpenFaaS token when invoking the function")
invokeCmd.Flags().BoolVarP(&invokeAsync, "async", "a", false, "Invoke the function asynchronously")
invokeCmd.Flags().StringVarP(&httpMethod, "method", "m", "POST", "pass HTTP request method")
invokeCmd.Flags().BoolVar(&tlsInsecure, "tls-no-verify", false, "Disable TLS validation")
Expand Down Expand Up @@ -63,32 +71,32 @@ var invokeCmd = &cobra.Command{
}

func runInvoke(cmd *cobra.Command, args []string) error {
var services stack.Services

if len(args) < 1 {
return fmt.Errorf("please provide a name for the function")
}
functionName = args[0]

if missingSignFlag(sigHeader, key) {
return fmt.Errorf("signing requires both --sign <header-value> and --key <key-value>")
}

var yamlGateway string
functionName = args[0]
err := validateHTTPMethod(httpMethod)
if err != nil {
return nil
}

if len(yamlFile) > 0 {
welteki marked this conversation as resolved.
Show resolved Hide resolved
parsedServices, err := stack.ParseYAMLFile(yamlFile, regex, filter, envsubst)
if err != nil {
return err
}
httpHeader, err := parseHeaders(headers)
if err != nil {
return err
}

if parsedServices != nil {
services = *parsedServices
yamlGateway = services.Provider.GatewayURL
}
httpQuery, err := parseQueryValues(query)
if err != nil {
return err
}

gatewayAddress := getGatewayURL(gateway, defaultGateway, yamlGateway, os.Getenv(openFaaSURLEnvironment))
httpHeader.Set("Content-Type", contentType)
httpHeader.Set("User-Agent", fmt.Sprintf("faas-cli/%s (openfaas; %s; %s)", version.BuildVersion(), runtime.GOOS, runtime.GOARCH))

stat, _ := os.Stdin.Stat()
if (stat.Mode() & os.ModeCharDevice) != 0 {
Expand All @@ -101,38 +109,170 @@ func runInvoke(cmd *cobra.Command, args []string) error {
}

if len(sigHeader) > 0 {
signedHeader, err := generateSignedHeader(functionInput, key, sigHeader)
if err != nil {
return fmt.Errorf("unable to sign message: %s", err.Error())
}
headers = append(headers, signedHeader)
sig := generateSignature(functionInput, key)
httpHeader.Add(sigHeader, sig)
}

response, err := proxy.InvokeFunction(gatewayAddress, functionName, &functionInput, contentType, query, headers, invokeAsync, httpMethod, tlsInsecure, functionInvokeNamespace)
client, err := GetDefaultSDKClient()
if err != nil {
return err
}

if response != nil {
os.Stdout.Write(*response)
u, _ := url.Parse("/")
u.RawQuery = httpQuery.Encode()

body := bytes.NewReader(functionInput)
req, err := http.NewRequest(httpMethod, u.String(), body)
if err != nil {
return err
}
req.Header = httpHeader

return nil
}
res, err := client.InvokeFunction(functionName, functionInvokeNamespace, invokeAsync, authenticate, req)
if err != nil {
return fmt.Errorf("failed to invoke function: %s", err)
}
if res.Body != nil {
defer res.Body.Close()
}

if !authenticate && res.StatusCode == http.StatusUnauthorized {
authenticateHeader := res.Header.Get("WWW-Authenticate")
realm := getRealm(authenticateHeader)

// Retry the request and authenticate with an OpenFaaS function access token if the realm directive in the
// WWW-Authenticate header is the function invoke realm.
if realm == functionInvokeRealm {
authenticate := true
body := bytes.NewReader(functionInput)
req, err := http.NewRequest(httpMethod, u.String(), body)
if err != nil {
return err
}
req.Header = httpHeader

res, err = client.InvokeFunction(functionName, functionInvokeNamespace, invokeAsync, authenticate, req)
if err != nil {
return fmt.Errorf("failed to invoke function: %s", err)
}
if res.Body != nil {
defer res.Body.Close()
}
}
}

func generateSignedHeader(message []byte, key string, headerName string) (string, error) {
if code := res.StatusCode; code < 200 || code > 299 {
resBody, err := io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("cannot read result from OpenFaaS on URL: %s %s", gateway, err)
}

return fmt.Errorf("server returned unexpected status code: %d - %s", res.StatusCode, string(resBody))
}

if len(headerName) == 0 {
return "", fmt.Errorf("signed header must have a non-zero length")
if invokeAsync && res.StatusCode == http.StatusAccepted {
fmt.Fprintf(os.Stderr, "Function submitted asynchronously.\n")
return nil
}

if _, err := io.Copy(os.Stdout, res.Body); err != nil {
return fmt.Errorf("cannot read result from OpenFaaS on URL: %s %s", gateway, err)
}

return nil
}

func generateSignature(message []byte, key string) string {
hash := hmac.Sign(message, []byte(key))
signature := hex.EncodeToString(hash)
signedHeader := fmt.Sprintf(`%s=%s=%s`, headerName, "sha1", string(signature[:]))

return signedHeader, nil
return fmt.Sprintf(`%s=%s`, "sha1", string(signature[:]))
}

func missingSignFlag(header string, key string) bool {
return (len(header) > 0 && len(key) == 0) || (len(header) == 0 && len(key) > 0)
}

// parseHeaders parses header values from the header command flag
func parseHeaders(headers []string) (http.Header, error) {
httpHeader := http.Header{}

for _, header := range headers {
headerVal := strings.SplitN(header, "=", 2)
if len(headerVal) != 2 {
return httpHeader, fmt.Errorf("the --header or -H flag must take the form of key=value")
}

key, value := headerVal[0], headerVal[1]
if key == "" {
return httpHeader, fmt.Errorf("the --header or -H flag must take the form of key=value (empty key given)")
}

if value == "" {
return httpHeader, fmt.Errorf("the --header or -H flag must take the form of key=value (empty value given)")
}

httpHeader.Add(key, value)
}

return httpHeader, nil
}

// parseQueryValues parses query values from the query command flags
func parseQueryValues(query []string) (url.Values, error) {
v := url.Values{}

for _, q := range query {
queryVal := strings.SplitN(q, "=", 2)
if len(queryVal) != 2 {
return v, fmt.Errorf("the --query flag must take the form of key=value")
}

key, value := queryVal[0], queryVal[1]
if key == "" {
return v, fmt.Errorf("the --header or -H flag must take the form of key=value (empty key given)")
}

if value == "" {
return v, fmt.Errorf("the --header or -H flag must take the form of key=value (empty value given)")
}

v.Add(key, value)
}

return v, nil
}

// validateMethod validates the HTTP request method
func validateHTTPMethod(httpMethod string) error {
var allowedMethods = []string{
http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete,
}
helpString := strings.Join(allowedMethods, "/")

if !contains(allowedMethods, httpMethod) {
return fmt.Errorf("the --method or -m flag must take one of these values (%s)", helpString)
}
return nil
}

// NOTE: This is far from a fully compliant parser per RFC 7235.
// It is only intended to correctly capture the realm directive in the
// known format as returned by the OpenFaaS watchdogs.
func getRealm(headerVal string) string {
parts := strings.SplitN(headerVal, " ", 2)

realm := ""
if len(parts) > 1 {
directives := strings.Split(parts[1], ", ")

for _, part := range directives {
if strings.HasPrefix(part, "realm=") {
realm = strings.Trim(part[6:], `"`)
break
}
}
}

return realm
}
Loading