Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use cuhpx for reordering on gpu #8

Merged
merged 3 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 91 additions & 3 deletions earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@
except ImportError:
healpixpad = None

__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d"]
try:
import cuhpx
except ImportError:
cuhpx = None

__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d", "reorder"]


def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
Expand Down Expand Up @@ -91,10 +96,57 @@ def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
return healpixpad.HEALPixPadFunction.apply(x.unsqueeze(2), padding).squeeze(2)


def _apply_cuhpx_remap(func, x, **kwargs):
shape = x.shape
x = x.view(-1, 1, shape[-1])
nside = npix2nside(x.shape[-1])
x = func(x.contiguous(), **kwargs, nside=nside)
x = x.contiguous()
x = x.view(shape[:-1] + (-1,))
return x


def npix2nside(npix: int):
nside = math.sqrt(npix // 12)
return int(nside)


def npix2level(npix: int):
return nside2level(npix2nside(npix))


def nside2level(nside: int):
return int(math.log2(nside))


class PixelOrder(Enum):
RING = 0
NEST = 1

def reorder_from_cuda(self, x, src: "PixelOrderT"):
akshaysubr marked this conversation as resolved.
Show resolved Hide resolved
if self == PixelOrder.RING:
return src.to_ring_cuda(x)
elif self == PixelOrder.NEST:
return src.to_nest_cuda(x)

def to_ring_cuda(self, x: torch.Tensor):
if self == PixelOrder.RING:
return x
elif self == PixelOrder.NEST:
return _apply_cuhpx_remap(cuhpx.nest2ring, x)

def to_nest_cuda(self, x: torch.Tensor):
if self == PixelOrder.RING:
return _apply_cuhpx_remap(cuhpx.ring2nest, x)
elif self == PixelOrder.NEST:
return x

def to_xy_cuda(self, x: torch.Tensor, dest: "XY"):
if self == PixelOrder.RING:
return _apply_cuhpx_remap(cuhpx.ring2flat, x, clockwise=dest.clockwise, origin=dest.origin.name)
elif self == PixelOrder.NEST:
akshaysubr marked this conversation as resolved.
Show resolved Hide resolved
return _apply_cuhpx_remap(cuhpx.nest2flat, x, clockwise=dest.clockwise, origin=dest.origin.name)


class Compass(Enum):
"""Cardinal directions in counter clockwise order"""
Expand Down Expand Up @@ -126,12 +178,42 @@ class XY:
origin: Compass = Compass.S
clockwise: bool = False

def reorder_from_cuda(self, x, src: "PixelOrderT"):
return src.to_xy_cuda(x, self)

def to_xy_cuda(self, x: torch.Tensor, dest: "XY"):
return _apply_cuhpx_remap(
cuhpx.flat2flat,
x,
src_origin=self.origin.name,
src_clockwise=self.clockwise,
dest_origin=dest.origin.name,
dest_clockwise=dest.clockwise,
)

def to_ring_cuda(self, x: torch.Tensor):
return _apply_cuhpx_remap(
cuhpx.flat2ring,
x,
origin=self.origin.name,
clockwise=self.clockwise,
)

def to_nest_cuda(self, x: torch.Tensor):
return _apply_cuhpx_remap(cuhpx.flat2nest, x, origin=self.origin.name, clockwise=self.clockwise)


PixelOrderT = Union[PixelOrder, XY]

HEALPIX_PAD_XY = XY(origin=Compass.N, clockwise=True)


def reorder(x: torch.Tensor, src_pixel_order: PixelOrderT, dest_pixel_order: PixelOrderT):
"""Reorder x from one pixel order to another"""
grid = Grid(level=npix2level(x.size(-1)), pixel_order=src_pixel_order)
return grid.reorder(dest_pixel_order, x)


def _convert_xyindex(nside: int, src: XY, dest: XY, i):
if src.clockwise != dest.clockwise:
i = _flip_xy(nside, i)
Expand Down Expand Up @@ -261,13 +343,19 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
def approximate_grid_length_meters(self):
return approx_grid_length_meters(self._nside())

def reorder(self, order: PixelOrderT, x: torch.Tensor) -> torch.Tensor:
"""Rorder the pixels of ``x`` to have ``order``"""
def _reorder_cpu(self, x: torch.Tensor, order: PixelOrderT):
akshaysubr marked this conversation as resolved.
Show resolved Hide resolved
output_grid = Grid(level=self.level, pixel_order=order)
i_nest = output_grid._nest_ipix()
i_me = self._nest2me(i_nest)
return x[..., i_me]

def reorder(self, order: PixelOrderT, x: torch.Tensor) -> torch.Tensor:
"""Rorder the pixels of ``x`` to have ``order``"""
if x.device.type == "cuda":
return order.reorder_from_cuda(x, self.pixel_order)
else:
return self._reorder_cpu(x, order)

def get_healpix_regridder(self, dest: "Grid"):
if self.level != dest.level:
return self.get_bilinear_regridder_to(dest.lat, dest.lon)
Expand Down
17 changes: 16 additions & 1 deletion tests/test_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_rotate_index(rot):

@pytest.mark.parametrize("origin", list(healpix.Compass))
@pytest.mark.parametrize("clockwise", [True, False])
def test_reorder(tmp_path, origin, clockwise):
def test_Grid_reorder(tmp_path, origin, clockwise):
src_grid = healpix.Grid(level=4, pixel_order=healpix.XY(origin=origin, clockwise=clockwise))
dest_grid = healpix.Grid(level=4, pixel_order=healpix.PixelOrder.NEST)

Expand Down Expand Up @@ -145,6 +145,21 @@ def test_conv2d():
assert out.shape == (n, cout, 1, npix)


@pytest.mark.parametrize("nside", [16])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("src_pixel_order", [healpix.HEALPIX_PAD_XY, healpix.PixelOrder.RING, healpix.PixelOrder.NEST])
@pytest.mark.parametrize("dest_pixel_order", [healpix.HEALPIX_PAD_XY, healpix.PixelOrder.RING, healpix.PixelOrder.NEST])
def test_reorder(nside, src_pixel_order, dest_pixel_order, device):
# Generate some test data
if device == "cuda" and torch.cuda.device_count() == 0:
pytest.skip("no cuda devices available")

data = torch.randn(1, 2, 12 * nside * nside, device=device)
out = healpix.reorder(data, src_pixel_order, dest_pixel_order)
out = healpix.reorder(out, dest_pixel_order, src_pixel_order)
assert torch.all(data == out), data - out


def test_latlon_cuda_set_device_regression():
"""See https://github.com/NVlabs/earth2grid/issues/6"""

Expand Down
2 changes: 1 addition & 1 deletion tests/test_latlon.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_lat_lon_bilinear_regrid_to():

regrid.float()
lat = torch.broadcast_to(torch.tensor(src.lat), src.shape)
z = torch.tensor(lat).float()
z = lat.float()

out = regrid(z)
assert out.shape == dest.shape
Loading