From 4b2ceb64fae53d290d2f8bd8a38518abd6d58cf0 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Wed, 28 Feb 2024 18:54:43 -0300 Subject: [PATCH] fix mixed precision --- DESCRIPTION | 4 +++- R/callbacks-amp.R | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 56d1328a..7d393d84 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -17,7 +17,7 @@ Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.2.3 Imports: - torch (>= 0.9.0), + torch (>= 0.12.9000), magrittr, zeallot, rlang (>= 1.0.0), @@ -68,3 +68,5 @@ Collate: 'module.R' 'reexports.R' 'serialization.R' +Remotes: + mlverse/torch diff --git a/R/callbacks-amp.R b/R/callbacks-amp.R index 9a2ecc57..6f5291c0 100644 --- a/R/callbacks-amp.R +++ b/R/callbacks-amp.R @@ -19,7 +19,7 @@ NULL luz_callback_mixed_precision <- luz_callback( "mixed_precision_callback", initialize = function(...) { - self$autocast_env <- rlang::new_environment() + self$autocast_context <- NULL self$scaler <- torch::cuda_amp_grad_scaler(...) }, on_fit_begin = function() { @@ -30,10 +30,11 @@ luz_callback_mixed_precision <- luz_callback( }, on_train_batch_begin = function() { device_type <- if (grepl("cuda", ctx$device)) "cuda" else ctx$device - torch::local_autocast(device_type = device_type, .env = self$autocast_env) + self$autocast_context <- torch::set_autocast(device_type = device_type) }, on_train_batch_after_loss = function() { - withr::deferred_run(self$autocast_env) + torch::unset_autocast(self$autocast_context) + self$autocast_context <- NULL }, on_train_batch_before_backward = function() { torch::with_enable_grad({