diff --git a/ext/InterpolationsRegridderExt.jl b/ext/InterpolationsRegridderExt.jl index f64447aa..784e6191 100644 --- a/ext/InterpolationsRegridderExt.jl +++ b/ext/InterpolationsRegridderExt.jl @@ -5,6 +5,7 @@ import Interpolations as Intp import ClimaCore import ClimaCore.Fields: Adapt import ClimaCore.Fields: ClimaComms +import ClimaCore.Fields: zeros import ClimaUtilities.Regridders @@ -12,6 +13,7 @@ struct InterpolationsRegridder{ SPACE <: ClimaCore.Spaces.AbstractSpace, FIELD <: ClimaCore.Fields.Field, BC, + GITP, } <: Regridders.AbstractRegridder """ClimaCore.Space where the output Field will be defined""" @@ -22,6 +24,10 @@ struct InterpolationsRegridder{ """Tuple of extrapolation conditions as accepted by Interpolations.jl""" extrapolation_bc::BC + + # This is needed because Adapt moves from CPU to GPU and allocates new memory + """Preallocated area of memory where to store the GPU interpolant (if needed)""" + _gpuitp::GITP end # Note, we swap Lat and Long! This is because according to the CF conventions longitude @@ -69,7 +75,12 @@ function Regridders.InterpolationsRegridder( end end - return InterpolationsRegridder(target_space, coordinates, extrapolation_bc) + return InterpolationsRegridder( + target_space, + coordinates, + extrapolation_bc, + zeros(target_space), + ) end """ @@ -90,7 +101,8 @@ function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions) ) # Move it to GPU (if needed) - gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp) + itp._gpuitp .= + Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp) return map(regridder.coordinates) do coord gpuitp(totuple(coord)...)