Skip to content

Commit

Permalink
trivial resampling
Browse files Browse the repository at this point in the history
  • Loading branch information
azrael417 committed Jan 23, 2025
1 parent 52a63a7 commit 6b95694
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torch_harmonics/distributed/distributed_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def __init__(
self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
self.register_buffer("lon_weights", lon_weights, persistent=False)

self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)

def extra_repr(self):
r"""
Pretty print module
Expand Down Expand Up @@ -167,6 +169,9 @@ def _upscale_latitudes(self, x: torch.Tensor):

def forward(self, x: torch.Tensor):

if self.skip_resampling:
return x

# transpose data so that h is local, and channels are split
num_chans = x.shape[-3]

Expand Down
6 changes: 6 additions & 0 deletions torch_harmonics/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def __init__(
self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
self.register_buffer("lon_weights", lon_weights, persistent=False)

self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)


def extra_repr(self):
r"""
Pretty print module
Expand Down Expand Up @@ -152,6 +155,9 @@ def _upscale_latitudes(self, x: torch.Tensor):
return x

def forward(self, x: torch.Tensor):
if self.skip_resampling:
return x

if self.expand_poles:
x = self._expand_poles(x)
x = self._upscale_latitudes(x)
Expand Down

0 comments on commit 6b95694

Please sign in to comment.