diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml
index be6f40e0..75f44288 100644
--- a/.github/workflows/R-CMD-check.yaml
+++ b/.github/workflows/R-CMD-check.yaml
@@ -53,18 +53,20 @@ jobs:
extra-packages: any::rcmdcheck
needs: check
+ - run: |
+ print(torch::torch_is_installed())
+ print(torch::backends_mps_is_available())
+ shell: Rscript {0}
+
- uses: r-lib/actions/check-r-package@v2
with:
error-on: '"error"'
args: 'c("--no-multiarch", "--no-manual", "--as-cran")'
GPU:
- runs-on: ['self-hosted', 'gce', 'gpu']
+ runs-on: ['self-hosted', 'gpu-local']
name: 'gpu'
-
- container:
- image: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu18.04
- options: --runtime=nvidia --gpus all
+ container: {image: 'nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04', options: '--gpus all --runtime=nvidia'}
env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
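
The step added above only prints two diagnostics before the package check runs. A rough local equivalent, assuming a working torch installation (install_torch() is only needed the first time), would be:

    # Reproduce the new CI diagnostic step from an interactive R session
    if (!torch::torch_is_installed()) torch::install_torch()  # fetch the libtorch/lantern binaries
    print(torch::torch_is_installed())           # should be TRUE before R CMD check runs
    print(torch::backends_mps_is_available())    # TRUE only where the MPS backend is usable
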
diff --git a/.github/workflows/test-coverage-pak.yaml b/.github/workflows/test-coverage-pak.yaml
index bf9f1008..2a2297e9 100644
--- a/.github/workflows/test-coverage-pak.yaml
+++ b/.github/workflows/test-coverage-pak.yaml
@@ -13,10 +13,10 @@ name: test-coverage
jobs:
test-coverage:
- runs-on: ['self-hosted', 'gce', 'gpu']
+ runs-on: ['self-hosted', 'gpu-local']
container:
- image: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu18.04
+ image: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
options: --gpus all
env:
diff --git a/DESCRIPTION b/DESCRIPTION
index 1cc8dcc9..3d216aa1 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
Package: luz
Title: Higher Level 'API' for 'torch'
-Version: 0.4.0.9000
+Version: 0.4.0.9002
Authors@R: c(
person("Daniel", "Falbel", email = "daniel@rstudio.com", role = c("aut", "cre", "cph")),
person(family = "RStudio", role = c("cph"))
@@ -17,7 +17,7 @@ Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
Imports:
- torch (>= 0.9.0),
+ torch (>= 0.11.9000),
magrittr,
zeallot,
rlang (>= 1.0.0),
@@ -69,4 +69,4 @@ Collate:
'reexports.R'
'serialization.R'
Remotes:
- mlverse/torch
+ mlverse/torch
diff --git a/NEWS.md b/NEWS.md
index 37501bff..b6e9b40f 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -5,6 +5,7 @@
* Fixed a bug when trying to resume models trained with learning rate schedulers. (#137)
* Added support for learning rate schedulers that take the current loss as arguments. (#140)
+
# luz 0.4.0
## Breaking changes
diff --git a/R/accelerator.R b/R/accelerator.R
index 73c0a674..51be0733 100644
--- a/R/accelerator.R
+++ b/R/accelerator.R
@@ -93,7 +93,7 @@ LuzAcceleratorState <- R6::R6Class(
if (torch::cuda_is_available())
paste0("cuda:", index)
- else if (torch::backends_mps_is_available())
+ else if (can_use_mps())
"mps"
else
"cpu"
diff --git a/R/callbacks-amp.R b/R/callbacks-amp.R
index 9a2ecc57..6f5291c0 100644
--- a/R/callbacks-amp.R
+++ b/R/callbacks-amp.R
@@ -19,7 +19,7 @@ NULL
luz_callback_mixed_precision <- luz_callback(
"mixed_precision_callback",
initialize = function(...) {
- self$autocast_env <- rlang::new_environment()
+ self$autocast_context <- NULL
self$scaler <- torch::cuda_amp_grad_scaler(...)
},
on_fit_begin = function() {
@@ -30,10 +30,11 @@ luz_callback_mixed_precision <- luz_callback(
},
on_train_batch_begin = function() {
device_type <- if (grepl("cuda", ctx$device)) "cuda" else ctx$device
- torch::local_autocast(device_type = device_type, .env = self$autocast_env)
+ self$autocast_context <- torch::set_autocast(device_type = device_type)
},
on_train_batch_after_loss = function() {
- withr::deferred_run(self$autocast_env)
+ torch::unset_autocast(self$autocast_context)
+ self$autocast_context <- NULL
},
on_train_batch_before_backward = function() {
torch::with_enable_grad({
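
The mixed-precision callback previously relied on torch::local_autocast() plus a withr environment that was flushed in on_train_batch_after_loss(); it now keeps the explicit context returned by torch::set_autocast() and closes it with torch::unset_autocast(). A minimal single-function sketch of that pairing (the real callback stores the context on self because the two calls live in different batch hooks; the helper name here is illustrative only):

    with_autocast_region <- function(device_type, forward) {
      autocast_ctx <- torch::set_autocast(device_type = device_type)  # enter autocast
      on.exit(torch::unset_autocast(autocast_ctx), add = TRUE)        # restore the previous state
      forward()                                                       # loss computed under autocast
    }
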
diff --git a/R/module.R b/R/module.R
index 76e4c561..ee3c48be 100644
--- a/R/module.R
+++ b/R/module.R
@@ -603,8 +603,13 @@ get_metrics.luz_module_evaluation <- function(object, ...) {
res[, c("metric", "value")]
}
+can_use_mps <- function() {
+ arch <- Sys.info()["machine"]
+ "arm64" %in% arch && torch::backends_mps_is_available()
+}
+
enable_mps_fallback <- function() {
- if (!torch::backends_mps_is_available())
+ if (!can_use_mps())
return(invisible(NULL))
fallback <- Sys.getenv("PYTORCH_ENABLE_MPS_FALLBACK", unset = "")
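
can_use_mps() tightens the earlier check: torch::backends_mps_is_available() alone is no longer enough, the process must also be running on an arm64 (Apple-silicon) build, presumably to avoid selecting MPS on Intel Macs. A quick illustration of how the guard behaves (values shown are example assumptions, not output from this patch):

    Sys.info()["machine"]    # e.g. "arm64" on Apple silicon, "x86_64" on Intel
    can_use_mps()            # TRUE only when arm64 and torch reports MPS support
    enable_mps_fallback()    # internal helper above; now a no-op unless can_use_mps() is TRUE
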
diff --git a/tests/testthat/_snaps/module-plot/ggplot2-histogram.svg b/tests/testthat/_snaps/module-plot/ggplot2-histogram.svg
index ef3a51cc..42b2218c 100644
--- a/tests/testthat/_snaps/module-plot/ggplot2-histogram.svg
+++ b/tests/testthat/_snaps/module-plot/ggplot2-histogram.svg
@@ -18,7 +18,7 @@
[one-line change to the regenerated snapshot SVG; the modified markup was not captured in this extract]