Skip to content

Commit

Permalink
adjust for MPS in Apple M3
Browse files Browse the repository at this point in the history
  • Loading branch information
gilbertocamara committed Dec 11, 2024
1 parent 483e153 commit 814a7ff
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 21 deletions.
45 changes: 24 additions & 21 deletions R/sits_classify.R
Original file line number Diff line number Diff line change
Expand Up @@ -239,27 +239,6 @@ sits_classify.raster_cube <- function(data,
# version is case-insensitive in sits
version <- .check_version(version)
.check_progress(progress)
# Get default proc bloat
proc_bloat <- .conf("processing_bloat_cpu")
# If we using the GPU, gpu_memory parameter needs to be specified
if (.torch_cuda_enabled(ml_model)) {
.check_int_parameter(gpu_memory, min = 1, max = 16384,
msg = .conf("messages", ".check_gpu_memory")
)
# Calculate available memory from GPU
memsize <- floor(gpu_memory - .torch_mem_info())
.check_int_parameter(memsize, min = 1,
msg = .conf("messages", ".check_gpu_memory_size")
)
proc_bloat <- .conf("processing_bloat_gpu")
}
# avoid memory race in Apple MPS
if (.torch_mps_enabled(ml_model)) {
memsize <- 1
gpu_memory <- 1
}
# save memsize for latter use
sits_env[["gpu_memory"]] <- gpu_memory
# Spatial filter
if (.has(roi)) {
roi <- .roi_as_sf(roi)
Expand Down Expand Up @@ -293,8 +272,11 @@ sits_classify.raster_cube <- function(data,
.check_samples_tile_match_timeline(samples = samples, tile = data)
# Do the samples and tile match their bands?
.check_samples_tile_match_bands(samples = samples, tile = data)

# Get block size
block <- .raster_file_blocksize(.raster_open_rast(.tile_path(data)))
# Get default proc bloat
proc_bloat <- .conf("processing_bloat_cpu")
# Check minimum memory needed to process one block
job_memsize <- .jobs_memsize(
job_size = .block_size(block = block, overlap = 0),
Expand All @@ -310,6 +292,27 @@ sits_classify.raster_cube <- function(data,
nbytes = 8,
proc_bloat = proc_bloat
)

# If we using the GPU, gpu_memory parameter needs to be specified
if (.torch_cuda_enabled(ml_model)) {
.check_int_parameter(gpu_memory, min = 1, max = 16384,
msg = .conf("messages", ".check_gpu_memory")
)
# Calculate available memory from GPU
memsize <- floor(gpu_memory - .torch_mem_info())
.check_int_parameter(memsize, min = 1,
msg = .conf("messages", ".check_gpu_memory_size")
)
proc_bloat <- .conf("processing_bloat_gpu")
}
# avoid memory race in Apple MPS
if (.torch_mps_enabled(ml_model)) {
warning(.conf("messages", "sits_classify_mps"),
call. = FALSE
)
}
# save memsize for latter use
sits_env[["gpu_memory"]] <- gpu_memory
# Update multicores parameter
multicores <- .jobs_max_multicores(
job_memsize = job_memsize,
Expand Down
1 change: 1 addition & 0 deletions inst/extdata/config_messages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ sits_bbox: "invalid bounding box - check input data"
sits_bbox_default: "input should be an object of class sits or raster_cube"
sits_classify_default: "input should be a valid set of training samples or a non-classified data cube"
sits_classify_derived_cube: "input data cube has already been classified"
sits_classify_mps: "using MPS - please check parameters memsize and gpu_memory \n MPS shares memory with gpu \n sum of memsize with gpu_memory must be less than total available RAM"
sits_classify_tbl_df: "input should be a sits tibble or a regular non-classified data cube"
sits_classify_sits: "wrong input parameters - see example in documentation"
sits_classify_raster: "wrong input parameters - see example in documentation"
Expand Down

0 comments on commit 814a7ff

Please sign in to comment.