
Commit

fix mixed precision
dfalbel committed Feb 28, 2024
1 parent 9be5427 commit 4b2ceb6
Showing 2 changed files with 7 additions and 4 deletions.
DESCRIPTION (4 changes: 3 additions & 1 deletion)
@@ -17,7 +17,7 @@ Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
 RoxygenNote: 7.2.3
 Imports:
-    torch (>= 0.9.0),
+    torch (>= 0.12.9000),
     magrittr,
     zeallot,
     rlang (>= 1.0.0),
@@ -68,3 +68,5 @@ Collate:
     'module.R'
     'reexports.R'
     'serialization.R'
+Remotes:
+    mlverse/torch
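
For context, the new Remotes field points dependency resolvers such as pak or remotes at the mlverse/torch GitHub repository, since the required torch version (>= 0.12.9000) is a development release that is not on CRAN. A minimal sketch of installing that development dependency directly, assuming pak is available:

# install the development torch referenced by the Remotes field
# (remotes::install_github("mlverse/torch") is an equivalent alternative)
pak::pak("mlverse/torch")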
R/callbacks-amp.R (7 changes: 4 additions & 3 deletions)
@@ -19,7 +19,7 @@ NULL
 luz_callback_mixed_precision <- luz_callback(
   "mixed_precision_callback",
   initialize = function(...) {
-    self$autocast_env <- rlang::new_environment()
+    self$autocast_context <- NULL
     self$scaler <- torch::cuda_amp_grad_scaler(...)
   },
   on_fit_begin = function() {
@@ -30,10 +30,11 @@ luz_callback_mixed_precision <- luz_callback(
   },
   on_train_batch_begin = function() {
     device_type <- if (grepl("cuda", ctx$device)) "cuda" else ctx$device
-    torch::local_autocast(device_type = device_type, .env = self$autocast_env)
+    self$autocast_context <- torch::set_autocast(device_type = device_type)
   },
   on_train_batch_after_loss = function() {
-    withr::deferred_run(self$autocast_env)
+    torch::unset_autocast(self$autocast_context)
+    self$autocast_context <- NULL
   },
   on_train_batch_before_backward = function() {
     torch::with_enable_grad({
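
Taken together, the callback now brackets the forward pass explicitly: the autocast context is opened in on_train_batch_begin() with torch::set_autocast() and closed again in on_train_batch_after_loss() with torch::unset_autocast(), instead of relying on a withr-deferred local_autocast(). Below is a minimal standalone sketch of the same pairing outside luz; it assumes a CUDA device and the development torch that exports set_autocast(), unset_autocast(), and cuda_amp_grad_scaler() as used in the diff, so treat it as illustrative rather than as luz internals.

library(torch)

model <- nn_linear(10, 1)$cuda()
opt <- optim_sgd(model$parameters, lr = 0.01)
scaler <- cuda_amp_grad_scaler()

x <- torch_randn(32, 10, device = "cuda")
y <- torch_randn(32, 1, device = "cuda")

# forward pass and loss under autocast, mirroring on_train_batch_begin()
autocast_context <- set_autocast(device_type = "cuda")
loss <- nnf_mse_loss(model(x), y)
# leave autocast before the backward pass, mirroring on_train_batch_after_loss()
unset_autocast(autocast_context)

# scale the loss so small float16 gradients do not underflow during backward
scaler$scale(loss)$backward()
scaler$step(opt)   # unscales gradients and skips the step if they contain inf/nan
scaler$update()
opt$zero_grad()

Holding the context object explicitly (and resetting it to NULL afterwards) makes it clear across the two callback hooks when autocast is active, something the deferred-environment approach left implicit.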
