From 72113c9c3075755058881f5ba299b259977c11b2 Mon Sep 17 00:00:00 2001 From: Petteri Ponsimaa Date: Wed, 6 Nov 2024 17:43:35 +0200 Subject: [PATCH] fix: remove index.html security metas, improve cf validation --- doc/security.md | 14 ++- internal/conf/config.go | 28 ++++-- internal/conf/config.yaml | 5 +- internal/conf/defaults.go | 4 +- internal/httpcontroller/auth_routes.go | 2 +- internal/httpcontroller/middleware.go | 2 +- internal/httpcontroller/routes.go | 2 +- internal/httpcontroller/server.go | 5 +- internal/security/cloudflare.go | 83 +++++++++++---- internal/security/cloudflare_test.go | 133 ++++++++++++++++--------- internal/security/oauth.go | 34 +++++-- internal/security/oauth_test.go | 77 +++++++++++++- views/index.html | 3 - views/settings/securitySettings.html | 6 +- 14 files changed, 294 insertions(+), 104 deletions(-) diff --git a/doc/security.md b/doc/security.md index 536b98d0..1940f74e 100644 --- a/doc/security.md +++ b/doc/security.md @@ -61,9 +61,19 @@ security: ### Cloudflare Access Authentication Bypass -Cloudflare Access provides an authentication layer that uses your existing identity providers, such as Google or GitHub accounts, -to control access to your applications. When using Cloudflare Access for authentication, you can configure BirdNET-Go to trust traffic coming through the Cloudflare tunnel. The system authenticates requests by validating the `Cf-Access-Jwt-Assertion` header containing a JWT token from Cloudflare. +Cloudflare Access provides an authentication layer that uses your existing identity providers, such as Google or GitHub accounts, to control access to your applications. When using Cloudflare Access for authentication, you can configure BirdNET-Go to trust traffic coming through the Cloudflare tunnel. The system authenticates requests by validating the `Cf-Access-Jwt-Assertion` header containing a JWT token from Cloudflare. +To add even more security, you can also require that the Cloudflare Team Domain Name and Policy audience are valid in the JWT token. Enable these by defining them in the `config.yaml` file: + +```yaml +security: + allowcloudflarebypass: + enabled: true + teamdomain: "your-subdomain-of-cloudflareaccess.com" + audience: "your-policy-auddience" +``` + +See the following links for more information on Cloudflare Access: - [Cloudflare tunnels](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/) - [Create a remotely-managed tunnel](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/get-started/create-remote-tunnel/) - [Self-hosted applications](https://developers.cloudflare.com/cloudflare-one/applications/configure-apps/self-hosted-apps/) diff --git a/internal/conf/config.go b/internal/conf/config.go index 6d0fd832..bfd41ad1 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -226,6 +226,17 @@ type SocialProvider struct { UserId string // valid user id for OAuth2 } +type AllowSubnetBypass struct { + Enabled bool // true to enable subnet bypass + Subnet string // disable OAuth2 in subnet +} + +type AllowCloudflareBypass struct { + Enabled bool // true to enable CF Access + TeamDomain string // Cloudflare team domain + Audience string // Cloudflare policy audience +} + // SecurityConfig handles all security-related settings and validations // for the application, including authentication, TLS, and access control. type Security struct { @@ -239,16 +250,13 @@ type Security struct { // Let's Encrypt. Requires Host to be set and port 80/443 access. AutoTLS bool - RedirectToHTTPS bool // true to redirect to HTTPS - AllowSubnetBypass struct { - Enabled bool // true to enable subnet bypass - Subnet string // disable OAuth2 in subnet - } - AllowCloudflareBypass bool // disable OAuth2 in Cloudflare tunnel - BasicAuth BasicAuth // password authentication configuration - GoogleAuth SocialProvider // Google OAuth2 configuration - GithubAuth SocialProvider // Github OAuth2 configuration - SessionSecret string // secret for session cookie + RedirectToHTTPS bool // true to redirect to HTTPS + AllowSubnetBypass AllowSubnetBypass // subnet bypass configuration + AllowCloudflareBypass AllowCloudflareBypass // Cloudflare Access configuration + BasicAuth BasicAuth // password authentication configuration + GoogleAuth SocialProvider // Google OAuth2 configuration + GithubAuth SocialProvider // Github OAuth2 configuration + SessionSecret string // secret for session cookie } // Settings contains all configuration options for the BirdNET-Go application. diff --git a/internal/conf/config.yaml b/internal/conf/config.yaml index 4eb066fc..d8f4e2b0 100644 --- a/internal/conf/config.yaml +++ b/internal/conf/config.yaml @@ -130,7 +130,10 @@ security: allowsubnetbypass: enabled: false # true to disable OAuth in subnet subnet: "" # comma-separated list of CIDR ranges (e.g., "192.168.1.0/24,10.0.0.0/8") - allowcftunnelbypass: false # true to disable OAuth for Cloudflare Tunnel requests + allowcloudflarebypass: + enabled: false # true to disable bypass for Cloudflare Tunnel + teamdomain: "" # Cloudflare Tunnel team domain + audience: "" # Cloudflare Tunnel policy audience basicauth: enabled: false # true to enable basic auth password: "" # password hash for the settings interface diff --git a/internal/conf/defaults.go b/internal/conf/defaults.go index 69f0ad7b..40128485 100644 --- a/internal/conf/defaults.go +++ b/internal/conf/defaults.go @@ -166,7 +166,9 @@ func setDefaultConfig() { viper.SetDefault("security.redirecttohttps", false) viper.SetDefault("security.allowsubnetbypass.enabled", false) viper.SetDefault("security.allowsubnetbypass.subnet", "") - viper.SetDefault("security.allowcloudflaretunnelbypass", false) + viper.SetDefault("security.allowcloudflarebypass.enabled", false) + viper.SetDefault("security.allowcloudflarebypass.teamdomain", "") + viper.SetDefault("security.allowcloudflarebypass.audience", "") // Basic authentication configuration viper.SetDefault("security.basic.enabled", false) diff --git a/internal/httpcontroller/auth_routes.go b/internal/httpcontroller/auth_routes.go index fa036530..4c212c73 100644 --- a/internal/httpcontroller/auth_routes.go +++ b/internal/httpcontroller/auth_routes.go @@ -133,7 +133,7 @@ func (s *Server) handleLogout(c echo.Context) error { gothic.Logout(c.Response(), c.Request()) //nolint:errcheck // Handle Cloudflare logout if enabled - if s.Settings.Security.AllowCloudflareBypass && s.CloudflareAccess.IsEnabled(c) { + if s.CloudflareAccess.IsEnabled(c) { return s.CloudflareAccess.Logout(c) } diff --git a/internal/httpcontroller/middleware.go b/internal/httpcontroller/middleware.go index 7c48abbd..6789511f 100644 --- a/internal/httpcontroller/middleware.go +++ b/internal/httpcontroller/middleware.go @@ -55,7 +55,7 @@ func (s *Server) AuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if isProtectedRoute(c.Path()) { // Check for Cloudflare bypass - if s.Settings.Security.AllowCloudflareBypass && s.CloudflareAccess.IsEnabled(c) { + if s.CloudflareAccess.IsEnabled(c) { return next(c) } diff --git a/internal/httpcontroller/routes.go b/internal/httpcontroller/routes.go index 8e1e0dd6..428e8637 100644 --- a/internal/httpcontroller/routes.go +++ b/internal/httpcontroller/routes.go @@ -147,7 +147,7 @@ func (s *Server) handlePageRequest(c echo.Context) error { path := c.Path() pageRoute, isPageRoute := s.pageRoutes[path] partialRoute, isFragment := s.partialRoutes[path] - isCloudflare := s.Settings.Security.AllowCloudflareBypass && s.CloudflareAccess.IsEnabled(c) + isCloudflare := s.CloudflareAccess.IsEnabled(c) // Return an error if route is unknown if !isPageRoute && !isFragment { diff --git a/internal/httpcontroller/server.go b/internal/httpcontroller/server.go index b7bd9274..d2ebcdee 100644 --- a/internal/httpcontroller/server.go +++ b/internal/httpcontroller/server.go @@ -49,7 +49,7 @@ func New(settings *conf.Settings, dataStore datastore.Interface, birdImageCache BirdImageCache: birdImageCache, AudioLevelChan: audioLevelChan, DashboardSettings: &settings.Realtime.Dashboard, - OAuth2Server: security.NewOAuth2Server(settings), + OAuth2Server: security.NewOAuth2Server(), CloudflareAccess: security.NewCloudflareAccess(), } @@ -107,8 +107,7 @@ func (s *Server) isAuthenticationEnabled(c echo.Context) bool { func (s *Server) IsAccessAllowed(c echo.Context) bool { // First check Cloudflare Access JWT - if s.Settings.Security.AllowCloudflareBypass && s.CloudflareAccess.IsEnabled(c) { - log.Printf("\033[1;35m*** IsAccessAllowed: Cloudflare Access token valid") + if s.CloudflareAccess.IsEnabled(c) { return true } diff --git a/internal/security/cloudflare.go b/internal/security/cloudflare.go index 99c68bc5..bad2d9d3 100644 --- a/internal/security/cloudflare.go +++ b/internal/security/cloudflare.go @@ -12,6 +12,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/labstack/echo/v4" + "github.com/tphakala/birdnet-go/internal/conf" ) type CloudflareAccessClaims struct { @@ -30,24 +31,34 @@ type CloudflareAccessClaims struct { type CloudflareAccess struct { certs map[string]string teamDomain string + audience string certCache struct { lastFetch time.Time mutex sync.RWMutex } + settings *conf.AllowCloudflareBypass + debug bool } func NewCloudflareAccess() *CloudflareAccess { + settings := conf.GetSettings() + cfBypass := settings.Security.AllowCloudflareBypass + return &CloudflareAccess{ - certs: make(map[string]string), + certs: make(map[string]string), + teamDomain: cfBypass.TeamDomain, + audience: cfBypass.Audience, certCache: struct { lastFetch time.Time mutex sync.RWMutex }{ lastFetch: time.Time{}, }, + settings: &cfBypass, } } +// fetchCertsIfNeeded fetches the certificates using a cache func (ca *CloudflareAccess) fetchCertsIfNeeded(issuer string) error { ca.certCache.mutex.RLock() cacheAge := time.Since(ca.certCache.lastFetch) @@ -73,9 +84,10 @@ func (ca *CloudflareAccess) fetchCertsIfNeeded(issuer string) error { // fetchCerts fetches the certificates from Cloudflare func (ca *CloudflareAccess) fetchCerts(issuer string) error { certsURL := fmt.Sprintf("%s/cdn-cgi/access/certs", issuer) - log.Printf("Fetching Cloudflare certs from URL: %s", certsURL) + ca.Debug("Fetching Cloudflare certs from URL: %s", certsURL) resp, err := http.Get(certsURL) + if err != nil { return fmt.Errorf("failed to fetch Cloudflare certs: %w", err) } @@ -90,14 +102,14 @@ func (ca *CloudflareAccess) fetchCerts(issuer string) error { Cert string `json:"cert"` } `json:"public_certs"` } - if err := json.NewDecoder(resp.Body).Decode(&certsResponse); err != nil { return fmt.Errorf("failed to decode certs response: %w", err) } + // Store the certificates with kids as keys for _, cert := range certsResponse.PublicCerts { ca.certs[cert.Kid] = cert.Cert - log.Printf("Added certificate with Kid: %s", cert.Kid) + ca.Debug("Added certificate with Kid: %s", cert.Kid) } return nil @@ -105,6 +117,11 @@ func (ca *CloudflareAccess) fetchCerts(issuer string) error { // IsEnabled returns true if Cloudflare Access is enabled func (ca *CloudflareAccess) IsEnabled(c echo.Context) bool { + + if !ca.settings.Enabled { + return false + } + claims, err := ca.VerifyAccessJWT(c.Request()) if err == nil && claims != nil { return true @@ -116,7 +133,7 @@ func (ca *CloudflareAccess) IsEnabled(c echo.Context) bool { func (ca *CloudflareAccess) VerifyAccessJWT(r *http.Request) (*CloudflareAccessClaims, error) { jwtToken := r.Header.Get("Cf-Access-Jwt-Assertion") if jwtToken == "" { - log.Println("No Cloudflare Access JWT found") + ca.Debug("No Cloudflare Access JWT found") return nil, fmt.Errorf("no Cloudflare Access JWT found") } @@ -125,7 +142,7 @@ func (ca *CloudflareAccess) VerifyAccessJWT(r *http.Request) (*CloudflareAccessC claims := &CloudflareAccessClaims{} token, _, err := parser.ParseUnverified(jwtToken, claims) if err != nil { - log.Printf("Failed to parse JWT: %v", err) + ca.Debug("Failed to parse JWT: %v", err) return nil, fmt.Errorf("failed to parse JWT: %w", err) } @@ -133,16 +150,23 @@ func (ca *CloudflareAccess) VerifyAccessJWT(r *http.Request) (*CloudflareAccessC if claims.Issuer != "" { parsedIssuer, err := url.Parse(claims.Issuer) if err != nil { - log.Printf("Invalid issuer URL: %v", err) + ca.Debug("Invalid issuer URL: %v", err) return nil, fmt.Errorf("invalid issuer URL: %w", err) } ca.teamDomain = strings.Split(parsedIssuer.Hostname(), ".")[0] + + // Validate team domain if configured + if ca.settings.TeamDomain != "" { + if ca.teamDomain != ca.settings.TeamDomain { + return nil, fmt.Errorf("team domain mismatch") + } + } } // Verify the JWT with the public key kid, ok := token.Header["kid"].(string) if !ok { - log.Println("No key ID in JWT header") + ca.Debug("No key ID in JWT header") return nil, fmt.Errorf("no key ID in JWT header") } @@ -154,7 +178,7 @@ func (ca *CloudflareAccess) VerifyAccessJWT(r *http.Request) (*CloudflareAccessC cert := ca.certs[kid] pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) if err != nil { - log.Printf("Failed to parse public key: %v", err) + ca.Debug("Failed to parse public key: %v", err) return nil, fmt.Errorf("failed to parse public key: %w", err) } @@ -164,35 +188,39 @@ func (ca *CloudflareAccess) VerifyAccessJWT(r *http.Request) (*CloudflareAccessC }) if err != nil { - log.Printf("Invalid JWT: %v", err) + ca.Debug("Invalid JWT: %v", err) return nil, fmt.Errorf("invalid JWT: %w", err) } if !token.Valid { - log.Println("Token is not valid") + ca.Debug("Token is not valid") return nil, fmt.Errorf("token is not valid") } if err := claims.Valid(); err != nil { - log.Printf("Invalid claims: %v", err) + ca.Debug("Invalid claims: %v", err) return nil, fmt.Errorf("invalid claims: %w", err) } - now := time.Now().Unix() - if claims.ExpiresAt < now { - log.Println("Token expired") - return nil, fmt.Errorf("token expired") - } - if claims.NotBefore > now { - log.Println("Token not yet valid") - return nil, fmt.Errorf("token not yet valid") + // Validate audience if configured + if ca.settings.Audience != "" { + audienceValid := false + for _, aud := range claims.Audience { + if aud == ca.settings.Audience { + audienceValid = true + break + } + } + if !audienceValid { + return nil, fmt.Errorf("audience mismatch") + } } + if claims.Type != "app" { - log.Printf("Invalid token type: %s", claims.Type) + ca.Debug("Invalid token type: %s", claims.Type) return nil, fmt.Errorf("invalid token type: %s", claims.Type) } - log.Println("Cloudflare Access JWT successfully verified") return claims, nil } @@ -238,6 +266,7 @@ func (ca *CloudflareAccess) Logout(c echo.Context) error { Value: "", Expires: time.Now().Add(-time.Hour), }) + ca.Debug("Logged out from Cloudflare Access") // Redirect to GetLogoutURL return c.Redirect(http.StatusFound, ca.GetLogoutURL()) @@ -246,3 +275,13 @@ func (ca *CloudflareAccess) Logout(c echo.Context) error { func (ca *CloudflareAccess) GetLogoutURL() string { return fmt.Sprintf("https://%s.cloudflareaccess.com/cdn-cgi/access/logout", ca.teamDomain) } + +func (ca *CloudflareAccess) Debug(format string, v ...interface{}) { + if !ca.debug { + if len(v) == 0 { + log.Print(format) + } else { + log.Printf(format, v...) + } + } +} diff --git a/internal/security/cloudflare_test.go b/internal/security/cloudflare_test.go index 251ada67..21bb4389 100644 --- a/internal/security/cloudflare_test.go +++ b/internal/security/cloudflare_test.go @@ -9,31 +9,62 @@ import ( "strings" "sync" "testing" + + //"github.com/prometheus/common/server" + "github.com/tphakala/birdnet-go/internal/conf" ) -// TestFetchCerts tests the fetchCerts method of the CloudflareAccess struct -func TestCloudflareAccess(t *testing.T) { +func BeforeEach(t *testing.T) (*httptest.Server, *CloudflareAccess) { + + // Add test certificates with proper PEM format + certsJSON := `{ + "public_certs": [ + { + "kid": "1234", + "cert": "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0test1\n-----END PUBLIC KEY-----" + }, + { + "kid": "5678", + "cert": "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0test2\n-----END PUBLIC KEY-----" + } + ] + }` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(r.URL.Path, "/cdn-cgi/access/certs") { + t.Errorf("Expected request to /cdn-cgi/access/certs, got %s", r.URL.Path) + } + fmt.Fprintln(w, certsJSON) + })) + + // Set the settings instance + conf.Setting() + + ca := NewCloudflareAccess() + ca.debug = true + ca.teamDomain = "test-team" + ca.audience = "test-audience" + + ca.settings = &conf.AllowCloudflareBypass{ + Enabled: true, + TeamDomain: "test-team", + Audience: "test-audience", + } + + return server, ca +} + +func TestCloudflareAccessSuite(t *testing.T) { tests := []struct { name string - setup func() (*httptest.Server, *CloudflareAccess) - verify func(*testing.T, *CloudflareAccess, error) + setup func(*httptest.Server) + test func(*testing.T, *CloudflareAccess, *httptest.Server) wantErr bool }{ { name: "successful certificate fetch", - setup: func() (*httptest.Server, *CloudflareAccess) { - certsJSON := `{ - "public_certs": [ - {"kid": "1234", "cert": "cert1"}, - {"kid": "5678", "cert": "cert2"} - ] - }` - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, certsJSON) - })) - return server, NewCloudflareAccess() - }, - verify: func(t *testing.T, ca *CloudflareAccess, err error) { + test: func(t *testing.T, ca *CloudflareAccess, server *httptest.Server) { + err := ca.fetchCerts(server.URL) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -41,37 +72,41 @@ func TestCloudflareAccess(t *testing.T) { t.Fatalf("Expected 2 certificates, got %d", len(ca.certs)) } }, - wantErr: false, }, - // Add more test cases here + { + name: "unsuccessful certificate fetch", + test: func(t *testing.T, ca *CloudflareAccess, server *httptest.Server) { + err := ca.fetchCerts("/invalid-url") + if err == nil { + t.Fatalf("Expected error, none") + } + if len(ca.certs) != 0 { + t.Fatalf("Expected zero certificates, got %d", len(ca.certs)) + } + }, + }, + // Add more test cases } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - server, ca := tt.setup() + server, ca := BeforeEach(t) defer server.Close() - err := ca.fetchCerts(server.URL) - tt.verify(t, ca, err) + if tt.setup != nil { + tt.setup(server) + } + + tt.test(t, ca, server) }) } } // TestFetchCertsSuccessProperlyUpdatesCertsMap tests the behavior of fetchCerts when the server returns a successful response func TestFetchCertsSuccessProperlyUpdatesCertsMap(t *testing.T) { - certsJSON := `{ - "public_certs": [ - {"kid": "1234", "cert": "cert1"}, - {"kid": "5678", "cert": "cert2"} - ] - }` - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, certsJSON) - })) + server, ca := BeforeEach(t) defer server.Close() - ca := &CloudflareAccess{certs: make(map[string]string)} err := ca.fetchCerts(server.URL) if err != nil { @@ -82,21 +117,26 @@ func TestFetchCertsSuccessProperlyUpdatesCertsMap(t *testing.T) { t.Fatalf("Expected 2 certificates, got %d", len(ca.certs)) } - if ca.certs["1234"] != "cert1" || ca.certs["5678"] != "cert2" { + expectedCert1 := "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0test1\n-----END PUBLIC KEY-----" + expectedCert2 := "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0test2\n-----END PUBLIC KEY-----" + + // Compare trimmed strings to handle whitespace differences + if strings.TrimSpace(ca.certs["1234"]) != strings.TrimSpace(expectedCert1) || + strings.TrimSpace(ca.certs["5678"]) != strings.TrimSpace(expectedCert2) { t.Fatalf("Certificates not stored correctly") } } // TestFetchCertsInvalidJSONResponse tests the behavior of fetchCerts when the server returns invalid JSON func TestFetchCertsInvalidJSONResponse(t *testing.T) { - certsJSON := `invalid JSON` - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, certsJSON) - })) + server, ca := BeforeEach(t) defer server.Close() - ca := &CloudflareAccess{certs: make(map[string]string)} + // Override server handler for invalid JSON case + server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `invalid JSON`) + }) + err := ca.fetchCerts(server.URL) if err == nil { @@ -107,10 +147,6 @@ func TestFetchCertsInvalidJSONResponse(t *testing.T) { if !strings.Contains(err.Error(), expectedErrMsg) { t.Fatalf("Expected error message to contain '%s', got '%s'", expectedErrMsg, err.Error()) } - - if len(ca.certs) != 0 { - t.Fatalf("Expected no certificates to be stored, got %d", len(ca.certs)) - } } // TestFetchCertsError tests the behavior of fetchCerts when the server returns an error @@ -126,13 +162,14 @@ func TestFetchCertsError(t *testing.T) { // TestFetchCertsEmptyResponse tests the behavior of fetchCerts when the server returns an empty response func TestFetchCertsEmptyResponse(t *testing.T) { + server, ca := BeforeEach(t) + defer server.Close() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Override server handler for empty response case + server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, `{ "public_certs": [] }`) - })) - defer server.Close() + }) - ca := &CloudflareAccess{certs: make(map[string]string)} err := ca.fetchCerts(server.URL) if err != nil { diff --git a/internal/security/oauth.go b/internal/security/oauth.go index 0a328919..ac855ede 100644 --- a/internal/security/oauth.go +++ b/internal/security/oauth.go @@ -36,23 +36,28 @@ type OAuth2Server struct { authCodes map[string]AuthCode accessTokens map[string]AccessToken mutex sync.RWMutex + debug bool GithubConfig *oauth2.Config GoogleConfig *oauth2.Config } -func NewOAuth2Server(config *conf.Settings) *OAuth2Server { +func NewOAuth2Server() *OAuth2Server { + settings := conf.GetSettings() + debug := settings.Debug + server := &OAuth2Server{ - Settings: config, + Settings: settings, authCodes: make(map[string]AuthCode), accessTokens: make(map[string]AccessToken), + debug: debug, } // Initialize Gothic with the provided configuration - InitializeGoth(config) + InitializeGoth(settings) // Clean up expired tokens every hour - server.StartTokenCleanup(time.Hour) + server.StartAuthCleanup(time.Hour) return server } @@ -88,17 +93,20 @@ func (s *OAuth2Server) UpdateProviders() { func (s *OAuth2Server) IsUserAuthenticated(c echo.Context) bool { if token, err := gothic.GetFromSession("access_token", c.Request()); err == nil && token != "" && s.ValidateAccessToken(token) { + s.Debug("User was authenticated with valid access_token") return true } userId, _ := gothic.GetFromSession("userId", c.Request()) if s.Settings.Security.GoogleAuth.Enabled { if googleUser, _ := gothic.GetFromSession("google", c.Request()); isValidUserId(s.Settings.Security.GoogleAuth.UserId, userId) && googleUser != "" { + s.Debug("User was authenticated with valid Google user") return true } } if s.Settings.Security.GithubAuth.Enabled { if githubUser, _ := gothic.GetFromSession("github", c.Request()); isValidUserId(s.Settings.Security.GithubAuth.UserId, userId) && githubUser != "" { + s.Debug("User was authenticated with valid GitHub user") return true } } @@ -118,7 +126,6 @@ func isValidUserId(configuredIds string, providedId string) bool { } } - log.Printf("User with userId is not allowed to login: %s", providedId) return false } @@ -200,7 +207,7 @@ func (s *OAuth2Server) IsRequestFromAllowedSubnet(ip string) bool { clientIP := net.ParseIP(ip) log.Printf("*** %s", clientIP) if clientIP == nil { - log.Printf("Invalid IP address: %s", ip) + s.Debug("Invalid IP address: %s", ip) return false } @@ -210,14 +217,16 @@ func (s *OAuth2Server) IsRequestFromAllowedSubnet(ip string) bool { for _, subnet := range subnets { _, ipNet, err := net.ParseCIDR(strings.TrimSpace(subnet)) if err == nil && ipNet.Contains(clientIP) { + s.Debug("Access allowed for IP %s", clientIP) return true } } + s.Debug("IP %s is not in the allowed subnet", clientIP) return false } -func (s *OAuth2Server) StartTokenCleanup(interval time.Duration) { +func (s *OAuth2Server) StartAuthCleanup(interval time.Duration) { go func() { ticker := time.NewTicker(interval) defer ticker.Stop() @@ -241,6 +250,17 @@ func (s *OAuth2Server) StartTokenCleanup(interval time.Duration) { } s.mutex.Unlock() + s.Debug("Token & code cleanup completed") } }() } + +func (s *OAuth2Server) Debug(format string, v ...interface{}) { + if s.debug { + if len(v) == 0 { + log.Print(format) + } else { + log.Printf(format, v...) + } + } +} diff --git a/internal/security/oauth_test.go b/internal/security/oauth_test.go index d980e09f..da979ec1 100644 --- a/internal/security/oauth_test.go +++ b/internal/security/oauth_test.go @@ -15,13 +15,16 @@ import ( // TestIsUserAuthenticatedValidAccessToken tests the IsUserAuthenticated function with a valid access token func TestIsUserAuthenticatedValidAccessToken(t *testing.T) { + // Set the settings instance + conf.Setting() + settings := &conf.Settings{ Security: conf.Security{ SessionSecret: "test-secret", }, } - s := NewOAuth2Server(settings) + s := NewOAuth2Server() // Initialize gothic exactly as in production gothic.Store = sessions.NewCookieStore([]byte(settings.Security.SessionSecret)) @@ -55,6 +58,9 @@ func TestIsUserAuthenticatedValidAccessToken(t *testing.T) { // TestIsUserAuthenticatedInvalidAccessToken tests the IsUserAuthenticated function with an invalid access token func TestIsUserAuthenticated(t *testing.T) { + // Set the settings instance + conf.Setting() + tests := []struct { name string token string @@ -77,7 +83,7 @@ func TestIsUserAuthenticated(t *testing.T) { }, } - s := NewOAuth2Server(settings) + s := NewOAuth2Server() // Initialize gothic exactly as in production gothic.Store = sessions.NewCookieStore([]byte(settings.Security.SessionSecret)) @@ -109,3 +115,70 @@ func TestIsUserAuthenticated(t *testing.T) { }) } } + +func TestOAuth2Server(t *testing.T) { + // Set the settings instance + conf.Setting() + + tests := []struct { + name string + test func(*testing.T, *OAuth2Server) + }{ + { + name: "generate and validate auth code", + test: func(t *testing.T, s *OAuth2Server) { + // Initialize settings + s.Settings = &conf.Settings{ + Security: conf.Security{ + BasicAuth: conf.BasicAuth{ + Enabled: true, + ClientID: "test-client", + ClientSecret: "test-secret", + AuthCodeExp: 10 * time.Minute, + AccessTokenExp: 1 * time.Hour, + }, + }, + } + + // Generate and immediately use the auth code + code, err := s.GenerateAuthCode() + if err != nil { + t.Fatalf("Failed to generate auth code: %v", err) + } + + token, err := s.ExchangeAuthCode(code) + if err != nil { + t.Fatalf("Failed to exchange auth code: %v", err) + } + + if !s.ValidateAccessToken(token) { + t.Error("Token validation failed") + } + }, + }, + { + name: "subnet bypass validation", + test: func(t *testing.T, s *OAuth2Server) { + s.Settings.Security.AllowSubnetBypass = conf.AllowSubnetBypass{ + Enabled: true, + Subnet: "192.168.1.0/24", + } + + if !s.IsRequestFromAllowedSubnet("192.168.1.100") { + t.Error("Expected IP to be allowed") + } + + if s.IsRequestFromAllowedSubnet("10.0.0.1") { + t.Error("Expected IP to be denied") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := NewOAuth2Server() + tt.test(t, s) + }) + } +} diff --git a/views/index.html b/views/index.html index e9acbefa..3e825f6f 100644 --- a/views/index.html +++ b/views/index.html @@ -12,9 +12,6 @@ - - diff --git a/views/settings/securitySettings.html b/views/settings/securitySettings.html index d21d8540..0e84fb05 100644 --- a/views/settings/securitySettings.html +++ b/views/settings/securitySettings.html @@ -265,7 +265,9 @@ enabled: {{.Settings.Security.AllowSubnetBypass.Enabled}}, subnet: '{{.Settings.Security.AllowSubnetBypass.Subnet}}' }, - allowCloudflareBypass: {{.Settings.Security.AllowCloudflareBypass}} + allowCloudflareBypass: { + enabled: {{.Settings.Security.AllowCloudflareBypass.Enabled}} + } }, bypassAuthOpen: false, showTooltip: null, @@ -295,7 +297,7 @@ {{template "checkbox" dict "id" "cloudflareBypassEnabled" - "model" "security.allowCloudflareBypass" + "model" "security.allowCloudflareBypass.enabled" "label" "Allow Cloudflare Access to Bypass Authentication" "tooltip" "Allow users authenticated through Cloudflare Access to be automatically granted access without additional login."}}