diff --git a/R/callbacks.R b/R/callbacks.R index 9ad2d77f..7f0622b4 100644 --- a/R/callbacks.R +++ b/R/callbacks.R @@ -435,7 +435,12 @@ luz_callback_lr_scheduler <- luz_callback( lr_scheduler(optimizer, ...) } self[[call_on]] <- function() { - self$scheduler$step() + if ("metrics" %in% names(formals(self$scheduler$step))) { + current_loss <- ctx$loss[[self$opt_name]] + self$scheduler$step(current_loss) + } else { + self$scheduler$step() + } } self$opt_name <- opt_name }, diff --git a/tests/testthat/_snaps/callbacks.md b/tests/testthat/_snaps/callbacks.md index 148c388f..604b7a50 100644 --- a/tests/testthat/_snaps/callbacks.md +++ b/tests/testthat/_snaps/callbacks.md @@ -13,6 +13,20 @@ Adjusting learning rate of group 1 to 0.0001 Adjusting learning rate of group 1 to 0.0000 +--- + + Code + expect_message({ + output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl, + verbose = FALSE, epochs = 20, callbacks = list(luz_callback_lr_scheduler( + torch::lr_reduce_on_plateau, verbose = TRUE, patience = 2, threshold = 0.1))) + }) + Message + Epoch 7: reducing learning rate of group 1 to 1.0000e-05 + Epoch 10: reducing learning rate of group 1 to 1.0000e-06 + Epoch 13: reducing learning rate of group 1 to 1.0000e-07 + Epoch 16: reducing learning rate of group 1 to 1.0000e-08 + # progressbar appears with training and validation Code diff --git a/tests/testthat/test-callbacks.R b/tests/testthat/test-callbacks.R index c2b01f7d..28d7d91e 100644 --- a/tests/testthat/test-callbacks.R +++ b/tests/testthat/test-callbacks.R @@ -25,6 +25,21 @@ test_that("callback lr scheduler", { }) }) + expect_snapshot({ + expect_message({ + output <- mod %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit(dl, verbose = FALSE, epochs = 20, callbacks = list( + luz_callback_lr_scheduler( + torch::lr_reduce_on_plateau, + verbose = TRUE, + patience = 2, + threshold = 1e-1 + ) + )) + }) + }) + }) test_that("csv callback", {