Skip to content

Commit

Permalink
try that
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel committed Feb 28, 2024
1 parent 03d04e1 commit 8d4b1c3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion R/accelerator.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ LuzAcceleratorState <- R6::R6Class(

if (torch::cuda_is_available())
paste0("cuda:", index)
else if (torch::backends_mps_is_available())
else if (can_use_mps())
"mps"
else
"cpu"
Expand Down
7 changes: 6 additions & 1 deletion R/module.R
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,13 @@ get_metrics.luz_module_evaluation <- function(object, ...) {
res[, c("metric", "value")]
}

can_use_mps <- function() {
arch <- Sys.info()["machine"]
"arm64" %in% arch && torch::backends_mps_is_available()
}

enable_mps_fallback <- function() {
if (!torch::backends_mps_is_available())
if (!can_use_mps())
return(invisible(NULL))

fallback <- Sys.getenv("PYTORCH_ENABLE_MPS_FALLBACK", unset = "")
Expand Down

0 comments on commit 8d4b1c3

Please sign in to comment.