diff --git a/NAMESPACE b/NAMESPACE
index 8c515a7322..18e8f98f90 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -472,6 +472,7 @@ export(optim_rprop)
 export(optim_sgd)
 export(optimizer)
 export(sampler)
+export(set_autocast)
 export(slc)
 export(tensor_dataset)
 export(torch_abs)
@@ -846,6 +847,7 @@ export(torch_vstack)
 export(torch_where)
 export(torch_zeros)
 export(torch_zeros_like)
+export(unset_autocast)
 export(with_autocast)
 export(with_detect_anomaly)
 export(with_device)
diff --git a/R/autocast.R b/R/autocast.R
index f010c4e7e7..e7afb4f4db 100644
--- a/R/autocast.R
+++ b/R/autocast.R
@@ -18,6 +18,7 @@
 #' @param dtype a torch data type indicating whether to use `torch_float16()` or `torch_bfloat16()`.
 #' @param cache_enabled a logical value indicating whether the weight cache inside autocast should be enabled.
 #' @param ... currently unused.
+#' @param context Returned by `set_autocast` and should be passed when unsetting it.
 #' @inheritParams with_no_grad
 #' @examples
 #' x <- torch_randn(5, 5, dtype = torch_float32())
@@ -34,6 +35,22 @@
 #' @seealso [cuda_amp_grad_scaler()] to perform dynamic gradient scaling.
 #' @export
 local_autocast <- function(device_type, dtype = NULL, enabled = TRUE, cache_enabled = NULL, ..., .env = parent.frame()) {
+  context <- set_autocast(device_type, dtype = dtype, enabled = enabled, cache_enabled = cache_enabled)
+  withr::defer({
+    unset_autocast(context)
+  }, envir = .env)
+}
+
+#' @describeIn local_autocast A with context for automatic mixed precision.
+#' @export
+with_autocast <- function(code, ... , device_type, dtype = NULL, enabled = TRUE, cache_enabled = NULL) {
+  local_autocast(device_type, dtype = dtype, enabled = enabled, cache_enabled = cache_enabled)
+  force(code)
+}
+
+#' @describeIn local_autocast Set the autocast context. For advanced users only.
+#' @export
+set_autocast <- function(device_type, dtype = NULL, enabled = TRUE, cache_enabled = NULL) {
   device <- device_type
 
   fast_dtype <- if (!is.null(dtype)) {
@@ -73,28 +90,35 @@ local_autocast <- function(device_type, dtype = NULL, enabled = TRUE, cache_enab
   prev_cache_enabled <- cpp_amp_autocast_is_cache_enabled()
   cpp_amp_autocast_set_cache_enabled(cache_enabled)
 
-  withr::defer({
-    if (device == "cpu") {
-      if (cpp_amp_autocast_decrease_nesting() == 0) {
-        cpp_amp_autocast_clear_cache()
-      }
-      cpp_amp_autocast_set_cpu_enabled(prev_enabled)
-      cpp_amp_autocast_set_cpu_dtype(prev_fast_dtype)
-    } else if (device == "cuda") {
-      if (cpp_amp_autocast_decrease_nesting() == 0) {
-        cpp_amp_autocast_clear_cache()
-      }
-      cpp_amp_autocast_set_gpu_enabled(prev_enabled)
-      cpp_amp_autocast_set_gpu_dtype(prev_fast_dtype)
-    }
-  }, envir = .env)
+  list(
+    device = device,
+    enabled = prev_enabled,
+    fast_dtype = prev_fast_dtype,
+    cache_enabled = prev_cache_enabled
+  )
 }
 
-#' @describeIn local_autocast A with context for automatic mixed precision.
+#' @describeIn local_autocast Unset the autocast context.
 #' @export
-with_autocast <- function(code, ... , device_type, dtype = NULL, enabled = TRUE, cache_enabled = NULL) {
-  local_autocast(device_type, dtype = dtype, enabled = enabled, cache_enabled = cache_enabled)
-  force(code)
+unset_autocast <- function(context) {
+  device <- context$device
+  prev_enabled <- context$enabled
+  prev_fast_dtype <- context$fast_dtype
+  prev_cache_enabled <- context$cache_enabled
+
+  if (device == "cpu") {
+    if (cpp_amp_autocast_decrease_nesting() == 0) {
+      cpp_amp_autocast_clear_cache()
+    }
+    cpp_amp_autocast_set_cpu_enabled(prev_enabled)
+    cpp_amp_autocast_set_cpu_dtype(prev_fast_dtype)
+  } else if (device == "cuda") {
+    if (cpp_amp_autocast_decrease_nesting() == 0) {
+      cpp_amp_autocast_clear_cache()
+    }
+    cpp_amp_autocast_set_gpu_enabled(prev_enabled)
+    cpp_amp_autocast_set_gpu_dtype(prev_fast_dtype)
+  }
 }
 
 #' Creates a gradient scaler
diff --git a/man/local_autocast.Rd b/man/local_autocast.Rd
index 7e3267b1b2..3acf1ca653 100644
--- a/man/local_autocast.Rd
+++ b/man/local_autocast.Rd
@@ -3,6 +3,8 @@
 \name{local_autocast}
 \alias{local_autocast}
 \alias{with_autocast}
+\alias{set_autocast}
+\alias{unset_autocast}
 \title{Autocast context manager}
 \usage{
 local_autocast(
@@ -22,6 +24,10 @@ with_autocast(
   enabled = TRUE,
   cache_enabled = NULL
 )
+
+set_autocast(device_type, dtype = NULL, enabled = TRUE, cache_enabled = NULL)
+
+unset_autocast(context)
 }
 \arguments{
 \item{device_type}{a character string indicating whether to use 'cuda' or 'cpu' device}
@@ -37,6 +43,8 @@ with_autocast(
 \item{.env}{The environment to use for scoping.}
 
 \item{code}{code to be executed with no gradient recording.}
+
+\item{context}{Returned by \code{set_autocast} and should be passed when unsetting it.}
 }
 \description{
 Allow regions of your code to run in mixed precision.
@@ -57,6 +65,10 @@ corresponding forward ops.
 \itemize{
 \item \code{with_autocast()}: A with context for automatic mixed precision.
 
+\item \code{set_autocast()}: Set the autocast context. For advanced users only.
+
+\item \code{unset_autocast()}: Unset the autocast context.
+
 }}
 \examples{
 if (torch_is_installed()) {
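
For context, a minimal usage sketch of the pair introduced above; it is not part of the patch, assumes a CPU device (where autocast typically falls back to bfloat16), and uses torch_mm() purely as an illustrative op. set_autocast() returns the previous autocast state, which must be handed back to unset_autocast(); local_autocast() and with_autocast() remain the scoped, self-restoring forms.

library(torch)

x <- torch_randn(5, 5, dtype = torch_float32())
y <- torch_randn(5, 5, dtype = torch_float32())

# Advanced path: set the autocast state by hand and restore it explicitly.
ctx <- set_autocast(device_type = "cpu")
z <- torch_mm(x, y)   # dispatched under autocast (typically bfloat16 on CPU)
unset_autocast(ctx)   # must receive the context returned by set_autocast()

# Equivalent scoped form that restores the previous state automatically:
z2 <- with_autocast(device_type = "cpu", {
  torch_mm(x, y)
})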