Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tkurth/torchification #66

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
162 changes: 147 additions & 15 deletions notebooks/resample_sphere.ipynb

Large diffs are not rendered by default.

55 changes: 28 additions & 27 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ def _precompute_convolution_tensor_dense(
nlat_out, nlon_out = out_shape

lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in)
lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out)

# compute the phi differences. We need to make the linspace exclusive to not double the last point
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
Expand All @@ -119,9 +117,9 @@ def _precompute_convolution_tensor_dense(

# compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization:
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

# array for accumulating non-zero indices
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
Expand Down Expand Up @@ -177,29 +175,29 @@ def setUp(self):
@parameterized.expand(
[
# regular convolution
[8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [2, 2], "morlet", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [3], "zernike", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 24), (8, 8), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (18, 36), (6, 12), [7], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4],
[8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4],
# transpose convolution
[8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [2, 2], "morlet", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [3], "zernike", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 8), (16, 24), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (6, 12), (18, 36), [7], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4],
[8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4],
]
)
def test_disco_convolution(
Expand All @@ -220,7 +218,10 @@ def test_disco_convolution(
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape

theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
if isinstance(kernel_shape, int):
theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1)
else:
theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)

Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
conv = Conv(
Expand Down
24 changes: 12 additions & 12 deletions tests/test_distributed_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,18 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist):

@parameterized.expand(
[
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 129, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 129, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
]
)
def test_distributed_disco_conv(
Expand Down
14 changes: 7 additions & 7 deletions tests/test_distributed_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,14 @@ def _gather_helper_bwd(self, tensor, B, C, resampling_dist):

@parameterized.expand(
[
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
]
)
def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the verbose option exist in all tests now? Would be good to have it for consistency

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not all tests, just the ones which previously printed things

):

B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
Expand Down Expand Up @@ -248,7 +248,7 @@ def test_distributed_resampling(
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
if verbose and (self.world_rank == )0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)

Expand All @@ -259,7 +259,7 @@ def test_distributed_resampling(
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)

err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
if verbose and (self.world_rank == 0):
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)

Expand Down
72 changes: 38 additions & 34 deletions tests/test_sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import unittest
from parameterized import parameterized
import math
import numpy as np
import torch
from torch.autograd import gradcheck
from torch_harmonics import *
Expand All @@ -41,31 +40,32 @@
class TestLegendrePolynomials(unittest.TestCase):

def setUp(self):
self.cml = lambda m, l: np.sqrt((2 * l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l - m) / math.factorial(l + m))
self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
self.pml = dict()

# preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
# for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
self.pml[(0, 0)] = lambda x: np.ones_like(x)
self.pml[(0, 0)] = lambda x: torch.ones_like(x)
self.pml[(0, 1)] = lambda x: x
self.pml[(1, 1)] = lambda x: -np.sqrt(1.0 - x**2)
self.pml[(1, 1)] = lambda x: -torch.sqrt(1.0 - x**2)
self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1)
self.pml[(1, 2)] = lambda x: -3 * x * np.sqrt(1.0 - x**2)
self.pml[(1, 2)] = lambda x: -3 * x * torch.sqrt(1.0 - x**2)
self.pml[(2, 2)] = lambda x: 3 * (1 - x**2)
self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x)
self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * np.sqrt(1.0 - x**2)
self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * torch.sqrt(1.0 - x**2)
self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2)
self.pml[(3, 3)] = lambda x: -15 * np.sqrt(1.0 - x**2) ** 3
self.pml[(3, 3)] = lambda x: -15 * torch.sqrt(1.0 - x**2) ** 3

self.lmax = self.mmax = 4

self.tol = 1e-9

def test_legendre(self):
print("Testing computation of associated Legendre polynomials")
def test_legendre(self, verbose=False):
if verbose:
print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import legpoly

t = np.linspace(0, 1, 100)
t = torch.linspace(0, 1, 100, dtype=torch.float64)
vdm = legpoly(self.mmax, self.lmax, t)

for l in range(self.lmax):
Expand All @@ -87,16 +87,17 @@ def setUp(self):

@parameterized.expand(
[
[256, 512, 32, "ortho", "equiangular", 1e-9],
[256, 512, 32, "ortho", "legendre-gauss", 1e-9],
[256, 512, 32, "four-pi", "equiangular", 1e-9],
[256, 512, 32, "four-pi", "legendre-gauss", 1e-9],
[256, 512, 32, "schmidt", "equiangular", 1e-9],
[256, 512, 32, "schmidt", "legendre-gauss", 1e-9],
[256, 512, 32, "ortho", "equiangular", 1e-9, False],
[256, 512, 32, "ortho", "legendre-gauss", 1e-9, False],
[256, 512, 32, "four-pi", "equiangular", 1e-9, False],
[256, 512, 32, "four-pi", "legendre-gauss", 1e-9, False],
[256, 512, 32, "schmidt", "equiangular", 1e-9, False],
[256, 512, 32, "schmidt", "legendre-gauss", 1e-9, False],
]
)
def test_sht(self, nlat, nlon, batch_size, norm, grid, tol):
print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
if verbose:
print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

testiters = [1, 2, 4, 8, 16]
if grid == "equiangular":
Expand All @@ -116,35 +117,38 @@ def test_sht(self, nlat, nlon, batch_size, norm, grid, tol):
# testing error accumulation
for iter in testiters:
with self.subTest(i=iter):
print(f"{iter} iterations of batchsize {batch_size}:")
if verbose:
print(f"{iter} iterations of batchsize {batch_size}:")

base = signal

for _ in range(iter):
base = isht(sht(base))

err = torch.mean(torch.norm(base - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))
print(f"final relative error: {err.item()}")
if verbose:
print(f"final relative error: {err.item()}")
self.assertTrue(err.item() <= tol)

@parameterized.expand(
[
[12, 24, 2, "ortho", "equiangular", 1e-5],
[12, 24, 2, "ortho", "legendre-gauss", 1e-5],
[12, 24, 2, "four-pi", "equiangular", 1e-5],
[12, 24, 2, "four-pi", "legendre-gauss", 1e-5],
[12, 24, 2, "schmidt", "equiangular", 1e-5],
[12, 24, 2, "schmidt", "legendre-gauss", 1e-5],
[15, 30, 2, "ortho", "equiangular", 1e-5],
[15, 30, 2, "ortho", "legendre-gauss", 1e-5],
[15, 30, 2, "four-pi", "equiangular", 1e-5],
[15, 30, 2, "four-pi", "legendre-gauss", 1e-5],
[15, 30, 2, "schmidt", "equiangular", 1e-5],
[15, 30, 2, "schmidt", "legendre-gauss", 1e-5],
[12, 24, 2, "ortho", "equiangular", 1e-5, False],
[12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
[12, 24, 2, "four-pi", "equiangular", 1e-5, False],
[12, 24, 2, "four-pi", "legendre-gauss", 1e-5, False],
[12, 24, 2, "schmidt", "equiangular", 1e-5, False],
[12, 24, 2, "schmidt", "legendre-gauss", 1e-5, False],
[15, 30, 2, "ortho", "equiangular", 1e-5, False],
[15, 30, 2, "ortho", "legendre-gauss", 1e-5, False],
[15, 30, 2, "four-pi", "equiangular", 1e-5, False],
[15, 30, 2, "four-pi", "legendre-gauss", 1e-5, False],
[15, 30, 2, "schmidt", "equiangular", 1e-5, False],
[15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
]
)
def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol):
print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
if verbose:
print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

if grid == "equiangular":
mmax = nlat // 2
Expand Down
Loading
Loading