Skip to content

Commit

Permalink
adding precompute longitude
Browse files Browse the repository at this point in the history
  • Loading branch information
azrael417 committed Jan 22, 2025
1 parent 33449fd commit f11da53
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 17 deletions.
8 changes: 4 additions & 4 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from torch.autograd import gradcheck
from torch_harmonics import *

from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes


def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
Expand Down Expand Up @@ -108,9 +108,9 @@ def _precompute_convolution_tensor_dense(
lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)

# compute the phi differences. We need to make the linspace exclusive to not double the last point
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1, dtype=torch.float64)[:-1]
# compute the phi differences.
lons_in = _precompute_longitudes(nlon_in)
lons_out = _precompute_longitudes(nlon_out)

# effective theta cutoff if multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
Expand Down
4 changes: 2 additions & 2 deletions torch_harmonics/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from functools import partial

from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis
Expand Down Expand Up @@ -178,7 +178,7 @@ def _precompute_convolution_tensor_s2(

# compute the phi differences
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
lons_in = _precompute_longitudes(nlon_in)

# compute quadrature weights and merge them into the convolution tensor.
# These quadrature integrate to 1 over the sphere.
Expand Down
4 changes: 2 additions & 2 deletions torch_harmonics/distributed/distributed_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from functools import partial

from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis
Expand Down Expand Up @@ -110,7 +110,7 @@ def _precompute_distributed_convolution_tensor_s2(

# compute the phi differences
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
lons_in = _precompute_longitudes(nlon_in)

# compute quadrature weights and merge them into the convolution tensor.
# These quadrature integrate to 1 over the sphere.
Expand Down
6 changes: 3 additions & 3 deletions torch_harmonics/distributed/distributed_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes
Expand Down Expand Up @@ -81,9 +81,9 @@ def __init__(

# for upscaling the latitudes we will use interpolation
self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
self.lons_in = torch.linspace(0, 2 * math.pi, nlon_in+1)[:-1]
self.lons_in = _precompute_longitudes(nlon_in)
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = torch.linspace(0, 2 * math.pi, nlon_out+1)[:-1]
self.lons_out = _precompute_longitudes(nlon_out)

# in the case where some points lie outside of the range spanned by lats_in,
# we need to expand the solution to the poles before interpolating
Expand Down
3 changes: 2 additions & 1 deletion torch_harmonics/examples/pde_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import torch
import torch.nn as nn
import torch_harmonics as harmonics
from torch_harmonics.quadrature import _precompute_longitudes

import math
import numpy as np
Expand Down Expand Up @@ -76,7 +77,7 @@ def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", rad

# apply cosine transform and flip them
lats = -torch.arcsin(cost)
lons = torch.linspace(0, 2*math.pi, self.nlon+1, dtype=torch.float64)[:nlon]
lons = _precompute_longitudes(self.nlon)

self.lmax = self.sht.lmax
self.mmax = self.sht.mmax
Expand Down
2 changes: 1 addition & 1 deletion torch_harmonics/examples/shallow_water_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", rad

# apply cosine transform and flip them
lats = -torch.arcsin(cost)
lons = torch.linspace(0, 2*math.pi, self.nlon+1, dtype=torch.float64)[:nlon]
lons = _precompute_longitudes(self.nlon)

self.lmax = self.sht.lmax
self.mmax = self.sht.mmax
Expand Down
6 changes: 5 additions & 1 deletion torch_harmonics/quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import numpy as np
import torch


def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0,
periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:

Expand All @@ -56,6 +55,11 @@ def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[floa

return xlg, wlg

@lru_cache(typed=True, copy=True):
def _precompute_longitudes(nlon: int):
lons = torch.linspace(0, 2 * math.pi, nlon+1, dtype=torch.float64)[:-1]
return lons


@lru_cache(typed=True, copy=True)
def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down
6 changes: 3 additions & 3 deletions torch_harmonics/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes


class ResampleS2(nn.Module):
Expand Down Expand Up @@ -67,9 +67,9 @@ def __init__(

# for upscaling the latitudes we will use interpolation
self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
self.lons_in = torch.linspace(0, 2 * math.pi, nlon_in+1)[:-1]
self.lons_in = _precompute_longitudes(nlon_in)
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = torch.linspace(0, 2 * math.pi, nlon_out+1)[:-1]
self.lons_out = _precompute_longitudes(nlon_out)

# in the case where some points lie outside of the range spanned by lats_in,
# we need to expand the solution to the poles before interpolating
Expand Down

0 comments on commit f11da53

Please sign in to comment.