Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve configuration reload #378

Merged
merged 3 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmd/realtime/realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/spf13/viper"
"github.com/tphakala/birdnet-go/internal/analysis"
"github.com/tphakala/birdnet-go/internal/conf"
"github.com/tphakala/birdnet-go/internal/httpcontroller/handlers"
)

// RealtimeCommand creates a new command for real-time audio analysis.
Expand All @@ -17,7 +18,8 @@ func Command(settings *conf.Settings) *cobra.Command {
Short: "Analyze audio in realtime mode",
Long: "Start analyzing incoming audio data in real-time looking for bird calls.",
RunE: func(cmd *cobra.Command, args []string) error {
return analysis.RealtimeAnalysis(settings)
notificationChan := make(chan handlers.Notification, 10)
return analysis.RealtimeAnalysis(settings, notificationChan)
},
}

Expand Down
43 changes: 40 additions & 3 deletions internal/analysis/realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/tphakala/birdnet-go/internal/datastore"
"github.com/tphakala/birdnet-go/internal/diskmanager"
"github.com/tphakala/birdnet-go/internal/httpcontroller"
"github.com/tphakala/birdnet-go/internal/httpcontroller/handlers"
"github.com/tphakala/birdnet-go/internal/myaudio"
"github.com/tphakala/birdnet-go/internal/telemetry"
"github.com/tphakala/birdnet-go/internal/weather"
Expand All @@ -28,7 +29,7 @@ import (
var audioLevelChan = make(chan myaudio.AudioLevelData, 100)

// RealtimeAnalysis initiates the BirdNET Analyzer in real-time mode and waits for a termination signal.
func RealtimeAnalysis(settings *conf.Settings) error {
func RealtimeAnalysis(settings *conf.Settings, notificationChan chan handlers.Notification) error {
// Initialize BirdNET interpreter
if err := initializeBirdNET(settings); err != nil {
return err
Expand Down Expand Up @@ -153,7 +154,7 @@ func RealtimeAnalysis(settings *conf.Settings) error {
startTelemetryEndpoint(&wg, settings, metrics, quitChan)

// start control monitor for hot reloads
startControlMonitor(&wg, controlChan, quitChan)
startControlMonitor(&wg, controlChan, quitChan, notificationChan)

// start quit signal monitor
monitorCtrlC(quitChan)
Expand Down Expand Up @@ -327,7 +328,7 @@ func initBirdImageCache(ds datastore.Interface, metrics *telemetry.Metrics) *ima
}

// startControlMonitor handles various control signals for realtime analysis mode
func startControlMonitor(wg *sync.WaitGroup, controlChan chan string, quitChan chan struct{}) {
func startControlMonitor(wg *sync.WaitGroup, controlChan chan string, quitChan chan struct{}, notificationChan chan handlers.Notification) {
wg.Add(1)
go func() {
defer wg.Done()
Expand All @@ -338,8 +339,44 @@ func startControlMonitor(wg *sync.WaitGroup, controlChan chan string, quitChan c
case "rebuild_range_filter":
if err := birdnet.BuildRangeFilter(bn); err != nil {
log.Printf("\033[31m❌ Error handling range filter rebuild: %v\033[0m", err)
notificationChan <- handlers.Notification{
Message: fmt.Sprintf("Failed to rebuild range filter: %v", err),
Type: "error",
}
Comment on lines +342 to +345
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider error handling for channel operations.

The notification channel writes could block if the channel is full. Consider adding a select statement with a timeout.

-						notificationChan <- handlers.Notification{
-							Message: fmt.Sprintf("Failed to rebuild range filter: %v", err),
-							Type:    "error",
-						}
+						select {
+						case notificationChan <- handlers.Notification{
+							Message: fmt.Sprintf("Failed to rebuild range filter: %v", err),
+							Type:    "error",
+						}:
+						case <-time.After(time.Second):
+							log.Printf("Warning: Notification channel blocked")
+						}

Also applies to: 348-351

} else {
log.Printf("\033[32m🔄 Range filter rebuilt successfully\033[0m")
notificationChan <- handlers.Notification{
Message: "Range filter rebuilt successfully",
Type: "success",
}
}
case "reload_birdnet":
if err := bn.ReloadModel(); err != nil {
log.Printf("\033[31m❌ Error reloading BirdNET model: %v\033[0m", err)
notificationChan <- handlers.Notification{
Message: fmt.Sprintf("Failed to reload BirdNET model: %v", err),
Type: "error",
}
} else {
log.Printf("\033[32m✅ BirdNET model reloaded successfully\033[0m")
notificationChan <- handlers.Notification{
Message: "BirdNET model reloaded successfully",
Type: "success",
}
// Rebuild range filter after model reload
Comment on lines +354 to +366
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add timeout handling for model reload operation.

The model reload operation could potentially be time-consuming. Consider adding a timeout to prevent blocking.

+					ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+					defer cancel()
+					
+					errChan := make(chan error, 1)
+					go func() {
+						errChan <- bn.ReloadModel()
+					}()
+
+					select {
+					case err := <-errChan:
						if err != nil {
							// ... error handling ...
						} else {
							// ... success handling ...
						}
+					case <-ctx.Done():
+						notificationChan <- handlers.Notification{
+							Message: "Model reload timed out after 30 seconds",
+							Type:    "error",
+						}
+					}

Committable suggestion skipped: line range outside the PR's diff.

if err := birdnet.BuildRangeFilter(bn); err != nil {
log.Printf("\033[31m❌ Error rebuilding range filter after model reload: %v\033[0m", err)
notificationChan <- handlers.Notification{
Message: fmt.Sprintf("Failed to rebuild range filter: %v", err),
Type: "error",
}
} else {
log.Printf("\033[32m✅ Range filter rebuilt successfully\033[0m")
notificationChan <- handlers.Notification{
Message: "Range filter rebuilt successfully",
Type: "success",
}
}
}
default:
log.Printf("Received unknown control signal: %v", signal)
Expand Down
98 changes: 95 additions & 3 deletions internal/birdnet/birdnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"runtime"
"strings"
"sync"
"time"

"github.com/tphakala/birdnet-go/internal/conf"
"github.com/tphakala/go-tflite"
Expand Down Expand Up @@ -45,7 +44,6 @@ type BirdNET struct {
AnalysisInterpreter *tflite.Interpreter
RangeInterpreter *tflite.Interpreter
Settings *conf.Settings
SpeciesListUpdated time.Time // Timestamp for the last update of the species list.
mu sync.Mutex
}

Expand Down Expand Up @@ -301,11 +299,105 @@ func (bn *BirdNET) loadModel() ([]byte, error) {
modelPath := bn.Settings.BirdNET.ModelPath
data, err := os.ReadFile(modelPath)
if err != nil {
return nil, fmt.Errorf("failed to read custom model file: %w", err)
return nil, fmt.Errorf("failed to read model file: %w", err)
}
return data, nil
}

// validateModelAndLabels checks if the number of labels matches the model's output size
func (bn *BirdNET) validateModelAndLabels() error {
// Get the output tensor to check its dimensions
outputTensor := bn.AnalysisInterpreter.GetOutputTensor(0)
if outputTensor == nil {
return fmt.Errorf("cannot get output tensor")
}

// Get the number of classes from the model's output tensor
modelOutputSize := outputTensor.Dim(outputTensor.NumDims() - 1)

// Compare with the number of labels
if len(bn.Settings.BirdNET.Labels) != modelOutputSize {
return fmt.Errorf("\033[31m❌ label count mismatch: model expects %d classes but label file has %d labels\033[0m",
modelOutputSize, len(bn.Settings.BirdNET.Labels))
}
Comment on lines +320 to +322
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid including ANSI color codes and emojis in error messages

Including ANSI color codes and emojis in error messages can lead to unreadable logs or issues in environments that do not support ANSI escape sequences. It is recommended to keep error messages plain to ensure they are properly displayed and logged across different systems.

Apply this diff to remove color codes and emojis from error messages:

-		return fmt.Errorf("\033[31m❌ label count mismatch: model expects %d classes but label file has %d labels\033[0m",
+		return fmt.Errorf("label count mismatch: model expects %d classes but label file has %d labels",
 			modelOutputSize, len(bn.Settings.BirdNET.Labels))

...

-		return fmt.Errorf("\033[31m❌ failed to reload model: %w\033[0m", err)
+		return fmt.Errorf("failed to reload model: %w", err)

...

-		return fmt.Errorf("\033[31m❌ failed to reload meta model: %w\033[0m", err)
+		return fmt.Errorf("failed to reload meta model: %w", err)

...

-		return fmt.Errorf("\033[31m❌ failed to reload labels: %w\033[0m", err)
+		return fmt.Errorf("failed to reload labels: %w", err)

...

-		return fmt.Errorf("\033[31m❌ model validation failed: %w\033[0m", err)
+		return fmt.Errorf("model validation failed: %w", err)

Also applies to: 341-341, 354-354, 370-370, 386-386


bn.Debug("\033[32m✅ Model validation successful: %d labels match model output size\033[0m", modelOutputSize)
return nil
}
Comment on lines +308 to +326
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure NumDims() is greater than zero to prevent potential panic

In the validateModelAndLabels method, calling outputTensor.Dim(outputTensor.NumDims() - 1) without checking if NumDims() is greater than zero could cause a panic if the tensor has zero dimensions. Please add a check to ensure NumDims() is greater than zero before accessing tensor dimensions.

Apply this diff to fix the issue:

 func (bn *BirdNET) validateModelAndLabels() error {
 	// Get the output tensor to check its dimensions
 	outputTensor := bn.AnalysisInterpreter.GetOutputTensor(0)
 	if outputTensor == nil {
 		return fmt.Errorf("cannot get output tensor")
 	}

+	if outputTensor.NumDims() == 0 {
+		return fmt.Errorf("output tensor has no dimensions")
+	}

 	// Get the number of classes from the model's output tensor
 	modelOutputSize := outputTensor.Dim(outputTensor.NumDims() - 1)

 	// Compare with the number of labels
 	if len(bn.Settings.BirdNET.Labels) != modelOutputSize {
 		return fmt.Errorf("\033[31m❌ label count mismatch: model expects %d classes but label file has %d labels\033[0m",
 			modelOutputSize, len(bn.Settings.BirdNET.Labels))
 	}

 	bn.Debug("\033[32m✅ Model validation successful: %d labels match model output size\033[0m", modelOutputSize)
 	return nil
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func (bn *BirdNET) validateModelAndLabels() error {
// Get the output tensor to check its dimensions
outputTensor := bn.AnalysisInterpreter.GetOutputTensor(0)
if outputTensor == nil {
return fmt.Errorf("cannot get output tensor")
}
// Get the number of classes from the model's output tensor
modelOutputSize := outputTensor.Dim(outputTensor.NumDims() - 1)
// Compare with the number of labels
if len(bn.Settings.BirdNET.Labels) != modelOutputSize {
return fmt.Errorf("\033[31m❌ label count mismatch: model expects %d classes but label file has %d labels\033[0m",
modelOutputSize, len(bn.Settings.BirdNET.Labels))
}
bn.Debug("\033[32m✅ Model validation successful: %d labels match model output size\033[0m", modelOutputSize)
return nil
}
func (bn *BirdNET) validateModelAndLabels() error {
// Get the output tensor to check its dimensions
outputTensor := bn.AnalysisInterpreter.GetOutputTensor(0)
if outputTensor == nil {
return fmt.Errorf("cannot get output tensor")
}
if outputTensor.NumDims() == 0 {
return fmt.Errorf("output tensor has no dimensions")
}
// Get the number of classes from the model's output tensor
modelOutputSize := outputTensor.Dim(outputTensor.NumDims() - 1)
// Compare with the number of labels
if len(bn.Settings.BirdNET.Labels) != modelOutputSize {
return fmt.Errorf("\033[31m❌ label count mismatch: model expects %d classes but label file has %d labels\033[0m",
modelOutputSize, len(bn.Settings.BirdNET.Labels))
}
bn.Debug("\033[32m✅ Model validation successful: %d labels match model output size\033[0m", modelOutputSize)
return nil
}


// ReloadModel safely reloads the BirdNET model and labels while handling ongoing analysis
func (bn *BirdNET) ReloadModel() error {
bn.Debug("\033[33m🔒 Acquiring mutex for model reload\033[0m")
bn.mu.Lock()
defer bn.mu.Unlock()
bn.Debug("\033[32m✅ Acquired mutex for model reload\033[0m")

// Store old interpreters to clean up after successful reload
oldAnalysisInterpreter := bn.AnalysisInterpreter
oldRangeInterpreter := bn.RangeInterpreter

// Initialize new model
if err := bn.initializeModel(); err != nil {
return fmt.Errorf("\033[31m❌ failed to reload model: %w\033[0m", err)
}
bn.Debug("\033[32m✅ Model initialized successfully\033[0m")

// Initialize new meta model
if err := bn.initializeMetaModel(); err != nil {
// Clean up the newly created analysis interpreter if meta model fails
if bn.AnalysisInterpreter != nil {
bn.AnalysisInterpreter.Delete()
}
// Restore the old interpreters
bn.AnalysisInterpreter = oldAnalysisInterpreter
bn.RangeInterpreter = oldRangeInterpreter
return fmt.Errorf("\033[31m❌ failed to reload meta model: %w\033[0m", err)
}
bn.Debug("\033[32m✅ Meta model initialized successfully\033[0m")

// Reload labels
if err := bn.loadLabels(); err != nil {
// Clean up the newly created interpreters if label loading fails
if bn.AnalysisInterpreter != nil {
bn.AnalysisInterpreter.Delete()
}
if bn.RangeInterpreter != nil {
bn.RangeInterpreter.Delete()
}
// Restore the old interpreters
bn.AnalysisInterpreter = oldAnalysisInterpreter
bn.RangeInterpreter = oldRangeInterpreter
return fmt.Errorf("\033[31m❌ failed to reload labels: %w\033[0m", err)
}
bn.Debug("\033[32m✅ Labels loaded successfully\033[0m")

// Validate that the model and labels match
if err := bn.validateModelAndLabels(); err != nil {
// Clean up the newly created interpreters if validation fails
if bn.AnalysisInterpreter != nil {
bn.AnalysisInterpreter.Delete()
}
if bn.RangeInterpreter != nil {
bn.RangeInterpreter.Delete()
}
// Restore the old interpreters
bn.AnalysisInterpreter = oldAnalysisInterpreter
bn.RangeInterpreter = oldRangeInterpreter
return fmt.Errorf("\033[31m❌ model validation failed: %w\033[0m", err)
}

// Clean up old interpreters after successful reload
if oldAnalysisInterpreter != nil {
oldAnalysisInterpreter.Delete()
}
if oldRangeInterpreter != nil {
oldRangeInterpreter.Delete()
}

bn.Debug("\033[32m✅ Model reload completed successfully\033[0m")
return nil
}

// Debug prints debug messages if debug mode is enabled
func (bn *BirdNET) Debug(format string, v ...interface{}) {
if bn.Settings.BirdNET.Debug {
Expand Down
4 changes: 3 additions & 1 deletion internal/httpcontroller/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Handlers struct {
AudioLevelChan chan myaudio.AudioLevelData // Channel for audio level updates
OAuth2Server *security.OAuth2Server
controlChan chan string
notificationChan chan Notification
}

// HandlerError is a custom error type that includes an HTTP status code and a user-friendly message.
Expand Down Expand Up @@ -72,7 +73,7 @@ func (bh *baseHandler) logInfo(message string) {
}

// New creates a new Handlers instance with the given dependencies.
func New(ds datastore.Interface, settings *conf.Settings, dashboardSettings *conf.Dashboard, birdImageCache *imageprovider.BirdImageCache, logger *log.Logger, sunCalc *suncalc.SunCalc, audioLevelChan chan myaudio.AudioLevelData, oauth2Server *security.OAuth2Server, controlChan chan string) *Handlers {
func New(ds datastore.Interface, settings *conf.Settings, dashboardSettings *conf.Dashboard, birdImageCache *imageprovider.BirdImageCache, logger *log.Logger, sunCalc *suncalc.SunCalc, audioLevelChan chan myaudio.AudioLevelData, oauth2Server *security.OAuth2Server, controlChan chan string, notificationChan chan Notification) *Handlers {
if logger == nil {
logger = log.New(os.Stderr, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
}
Expand All @@ -91,6 +92,7 @@ func New(ds datastore.Interface, settings *conf.Settings, dashboardSettings *con
AudioLevelChan: audioLevelChan,
OAuth2Server: oauth2Server,
controlChan: controlChan,
notificationChan: notificationChan,
}
}

Expand Down
58 changes: 54 additions & 4 deletions internal/httpcontroller/handlers/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,30 @@ func (h *Handlers) SaveSettings(c echo.Context) error {
return h.NewHandlerError(err, "Error updating settings", http.StatusInternalServerError)
}

// Check if BirdNET settings have changed
if birdnetSettingsChanged(oldSettings, *settings) {
h.SSE.SendNotification(Notification{
Message: "Reloading BirdNET model...",
Type: "info",
})

h.controlChan <- "reload_birdnet"
}

// Check if range filter related settings have changed
if rangeFilterSettingsChanged(oldSettings, *settings) {
//log.Println("Range filter settings changed, sending reload signal")
h.SSE.SendNotification(Notification{
Message: "Rebuilding range filter...",
Type: "info",
})
h.controlChan <- "rebuild_range_filter"
}

// Check the authentication settings and update if needed
h.updateAuthenticationSettings(settings)

// Check if audio equalizer settings have changed
if equalizerSettingsChanged(settings.Realtime.Audio.Equalizer, settings.Realtime.Audio.Equalizer) {
//log.Println("Debug (SaveSettings): Equalizer settings changed, reloading audio filters")
if equalizerSettingsChanged(oldSettings.Realtime.Audio.Equalizer, settings.Realtime.Audio.Equalizer) {
if err := myaudio.UpdateFilterChain(settings); err != nil {
h.SSE.SendNotification(Notification{
Message: fmt.Sprintf("Error updating audio EQ filters: %v", err),
Expand All @@ -88,7 +100,6 @@ func (h *Handlers) SaveSettings(c echo.Context) error {

// Save settings to YAML file
if err := conf.SaveSettings(); err != nil {
// Send error notification if saving settings fails
h.SSE.SendNotification(Notification{
Message: fmt.Sprintf("Error saving settings: %v", err),
Type: "error",
Expand Down Expand Up @@ -574,5 +585,44 @@ func rangeFilterSettingsChanged(oldSettings, currentSettings conf.Settings) bool
return true
}

// Check for changes in BirdNET range filter settings
if !reflect.DeepEqual(oldSettings.BirdNET.RangeFilter, currentSettings.BirdNET.RangeFilter) {
return true
}

// Check for changes in BirdNET latitude and longitude
if oldSettings.BirdNET.Latitude != currentSettings.BirdNET.Latitude || oldSettings.BirdNET.Longitude != currentSettings.BirdNET.Longitude {
return true
}

return false
}

func birdnetSettingsChanged(oldSettings, currentSettings conf.Settings) bool {
// Check for changes in BirdNET locale
if oldSettings.BirdNET.Locale != currentSettings.BirdNET.Locale {
return true
}

// Check for changes in BirdNET threads
if oldSettings.BirdNET.Threads != currentSettings.BirdNET.Threads {
return true
}

// Check for changes in BirdNET model path
if oldSettings.BirdNET.ModelPath != currentSettings.BirdNET.ModelPath {
return true
}

// Check for changes in BirdNET label path
if oldSettings.BirdNET.LabelPath != currentSettings.BirdNET.LabelPath {
return true
}

// Check for changes in BirdNET XNNPACK acceleration
if oldSettings.BirdNET.UseXNNPACK != currentSettings.BirdNET.UseXNNPACK {
return true
}

return false
}
4 changes: 3 additions & 1 deletion internal/httpcontroller/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Server struct {
SunCalc *suncalc.SunCalc
AudioLevelChan chan myaudio.AudioLevelData
controlChan chan string
notificationChan chan handlers.Notification

// Page and partial routes
pageRoutes map[string]PageRouteConfig
Expand All @@ -54,6 +55,7 @@ func New(settings *conf.Settings, dataStore datastore.Interface, birdImageCache
OAuth2Server: security.NewOAuth2Server(),
CloudflareAccess: security.NewCloudflareAccess(),
controlChan: controlChan,
notificationChan: make(chan handlers.Notification, 10),
}

// Configure an IP extractor
Expand All @@ -63,7 +65,7 @@ func New(settings *conf.Settings, dataStore datastore.Interface, birdImageCache
s.SunCalc = suncalc.NewSunCalc(settings.BirdNET.Latitude, settings.BirdNET.Longitude)

// Initialize handlers
s.Handlers = handlers.New(s.DS, s.Settings, s.DashboardSettings, s.BirdImageCache, nil, s.SunCalc, s.AudioLevelChan, s.OAuth2Server, s.controlChan)
s.Handlers = handlers.New(s.DS, s.Settings, s.DashboardSettings, s.BirdImageCache, nil, s.SunCalc, s.AudioLevelChan, s.OAuth2Server, s.controlChan, s.notificationChan)

s.initializeServer()
return s
Expand Down
Loading