Skip to content

Commit

Permalink
adding requires_grad = False to precomp routines
Browse files Browse the repository at this point in the history
  • Loading branch information
azrael417 committed Jan 22, 2025
1 parent 9826c07 commit 52a63a7
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 30 deletions.
7 changes: 0 additions & 7 deletions torch_harmonics/distributed/distributed_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,6 @@ def __init__(
lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
lat_weights = lat_weights.unsqueeze(-1)

# convert to tensor
#lat_idx = torch.LongTensor(lat_idx)

# register buffers
self.register_buffer("lat_idx", lat_idx, persistent=False)
self.register_buffer("lat_weights", lat_weights, persistent=False)
Expand All @@ -123,10 +120,6 @@ def __init__(
diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)
lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)

# convert to tensor
#lon_idx_left = torch.LongTensor(lon_idx_left)
#lon_idx_right = torch.LongTensor(lon_idx_right)

# register buffers
self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
Expand Down
4 changes: 2 additions & 2 deletions torch_harmonics/legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",

# compute the tensor P^m_n:
nmax = max(mmax,lmax)
vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64)
vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64, requires_grad=False)

norm_factor = 1. if norm == "ortho" else math.sqrt(4 * math.pi)
norm_factor = 1. / norm_factor if inverse else norm_factor
Expand Down Expand Up @@ -123,7 +123,7 @@ def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,

pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False)

dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64)
dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64, requires_grad=False)

# fill the derivative terms wrt theta
for l in range(0, lmax):
Expand Down
32 changes: 18 additions & 14 deletions torch_harmonics/quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[floa

@lru_cache(typed=True, copy=True)
def _precompute_longitudes(nlon: int):
lons = torch.linspace(0, 2 * math.pi, nlon+1, dtype=torch.float64)[:-1]
r"""
Convenience routine to precompute longitudes
"""

lons = torch.linspace(0, 2 * math.pi, nlon+1, dtype=torch.float64, requires_grad=False)[:-1]
return lons


Expand All @@ -66,15 +70,15 @@ def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple
r"""
Convenience routine to precompute latitudes
"""

# compute coordinates in the cosine theta domain
xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)

# to perform the quadrature and account for the jacobian of the sphere, the quadrature rule
# is formulated in the cosine theta domain, which is designed to integrate functions of cos theta
lats = torch.flip(torch.arccos(xlg), dims=(0,)).clone()
wlg = torch.flip(wlg, dims=(0,)).clone()

return lats, wlg


Expand All @@ -85,7 +89,7 @@ def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
"""

xlg = torch.from_numpy(np.linspace(a, b, n, endpoint=periodic))
wlg = (b - a) / (n - periodic * 1) * torch.ones(n)
wlg = (b - a) / (n - periodic * 1) * torch.ones(n, requires_grad=False)

if not periodic:
wlg[0] *= 0.5
Expand Down Expand Up @@ -116,12 +120,12 @@ def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
on the interval [a, b]
"""

wlg = torch.zeros((n,), dtype=torch.float64)
tlg = torch.zeros((n,), dtype=torch.float64)
tmp = torch.zeros((n,), dtype=torch.float64)
wlg = torch.zeros((n,), dtype=torch.float64, requires_grad=False)
tlg = torch.zeros((n,), dtype=torch.float64, requires_grad=False)
tmp = torch.zeros((n,), dtype=torch.float64, requires_grad=False)

# Vandermonde Matrix
vdm = torch.zeros((n, n), dtype=torch.float64)
vdm = torch.zeros((n, n), dtype=torch.float64, requires_grad=False)

# initialize Chebyshev nodes as first guess
for i in range(n):
Expand Down Expand Up @@ -162,7 +166,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]

assert n > 1

tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64))
tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False))

if n == 2:
wcc = torch.tensor([1.0, 1.0], dtype=torch.float64)
Expand All @@ -173,11 +177,11 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
l = len(N)
m = n1 - l

v = torch.cat([2 / N / (N - 2), 1 / N[-1:], torch.zeros(m, dtype=torch.float64)])
v = torch.cat([2 / N / (N - 2), 1 / N[-1:], torch.zeros(m, dtype=torch.float64, requires_grad=False)])
#v = 0 - v[:-1] - v[-1:0:-1]
v = 0 - v[:-1] - torch.flip(v[1:], dims=(0,))

g0 = -torch.ones(n1, dtype=torch.float64)
g0 = -torch.ones(n1, dtype=torch.float64, requires_grad=False)
g0[l] = g0[l] + n1
g0[m] = g0[m] + n1
g = g0 / (n1**2 - 1 + (n1 % 2))
Expand All @@ -201,14 +205,14 @@ def fejer2_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> T

assert n > 2

tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64))
tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False))

n1 = n - 1
N = torch.arange(1, n1, 2, dtype=torch.float64)
l = len(N)
m = n1 - l

v = torch.cat([2 / N / (N - 2), 1 / N[-1:], torch.zeros(m, dtype=torch.float64)])
v = torch.cat([2 / N / (N - 2), 1 / N[-1:], torch.zeros(m, dtype=torch.float64, requires_grad=False)])
#v = 0 - v[:-1] - v[-1:0:-1]
v = 0 - v[:-1] - torch.flip(v[1:], dims=(0,))

Expand Down
7 changes: 0 additions & 7 deletions torch_harmonics/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,6 @@ def __init__(
lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
lat_weights = lat_weights.unsqueeze(-1)

# convert to tensor
#lat_idx = torch.LongTensor(lat_idx)

# register buffers
self.register_buffer("lat_idx", lat_idx, persistent=False)
self.register_buffer("lat_weights", lat_weights, persistent=False)
Expand All @@ -109,10 +106,6 @@ def __init__(
diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)
lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)

# convert to tensor
#lon_idx_left = torch.LongTensor(lon_idx_left)
#lon_idx_right = torch.LongTensor(lon_idx_right)

# register buffers
self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
Expand Down

0 comments on commit 52a63a7

Please sign in to comment.