Skip to content

Commit

Permalink
Adap test and way to get the current loss
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel committed Dec 12, 2023
1 parent c8a23d7 commit 0da9c78
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
1 change: 1 addition & 0 deletions R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ luz_callback_lr_scheduler <- luz_callback(
}
self[[call_on]] <- function() {
if ("metrics" %in% names(formals(self$scheduler$step))) {
current_loss <- ctx$loss[[self$opt_name]]
self$scheduler$step(current_loss)
} else {
self$scheduler$step()
Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/_snaps/callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@
Adjusting learning rate of group 1 to 0.0001
Adjusting learning rate of group 1 to 0.0000

---

Code
expect_message({
output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl,
verbose = FALSE, epochs = 20, callbacks = list(luz_callback_lr_scheduler(
torch::lr_reduce_on_plateau, verbose = TRUE, patience = 2, threshold = 0.1)))
})
Message
Epoch 7: reducing learning rate of group 1 to 1.0000e-05
Epoch 10: reducing learning rate of group 1 to 1.0000e-06
Epoch 13: reducing learning rate of group 1 to 1.0000e-07
Epoch 16: reducing learning rate of group 1 to 1.0000e-08

# progressbar appears with training and validation

Code
Expand Down
9 changes: 7 additions & 2 deletions tests/testthat/test-callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ test_that("callback lr scheduler", {
expect_message({
output <- mod %>%
set_hparams(input_size = 10, output_size = 1) %>%
fit(dl, verbose = FALSE, epochs = 5, callbacks = list(
luz_callback_lr_scheduler(torch::lr_reduce_on_plateau, verbose = TRUE)
fit(dl, verbose = FALSE, epochs = 20, callbacks = list(
luz_callback_lr_scheduler(
torch::lr_reduce_on_plateau,
verbose = TRUE,
patience = 2,
threshold = 1e-1
)
))
})
})
Expand Down

0 comments on commit 0da9c78

Please sign in to comment.