Skip to content

Commit

Permalink
Improved computation of Morlet filter basis and switched to a Hann wi…
Browse files Browse the repository at this point in the history
…ndow.
  • Loading branch information
bonevbs committed Jan 20, 2025
1 parent 4129396 commit 6663841
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 47 deletions.
5 changes: 5 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## Versioning

### v0.7.5
* New normalization mode `support` for DISCO convolutions
* More efficient computation of Morlet filter basis
* Changed default for Morlet filter basis to a Hann window function

### v0.7.4

* New filter basis normalization in DISCO convolutions
Expand Down
55 changes: 21 additions & 34 deletions notebooks/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,79 +34,66 @@
import cartopy
import cartopy.crs as ccrs

def plot_sphere(data,
fig=None,
cmap="RdBu",
title=None,
colorbar=False,
coastlines=False,
central_latitude=0,
central_longitude=0,
lon=None,
lat=None,
**kwargs):

def plot_sphere(data, fig=None, cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_latitude=0, central_longitude=0, lon=None, lat=None, **kwargs):
if fig == None:
fig = plt.figure()

nlat = data.shape[-2]
nlon = data.shape[-1]
if lon is None:
lon = np.linspace(0, 2*np.pi, nlon)
lon = np.linspace(0, 2 * np.pi, nlon)
if lat is None:
lat = np.linspace(np.pi/2., -np.pi/2., nlat)
lat = np.linspace(np.pi / 2.0, -np.pi / 2.0, nlat)
Lon, Lat = np.meshgrid(lon, lat)

proj = ccrs.Orthographic(central_longitude=central_longitude, central_latitude=central_latitude)
# proj = ccrs.Mollweide(central_longitude=central_longitude)

ax = fig.add_subplot(projection=proj)
Lon = Lon*180/np.pi
Lat = Lat*180/np.pi
Lon = Lon * 180 / np.pi
Lat = Lat * 180 / np.pi

# contour data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
ax.add_feature(cartopy.feature.COASTLINE, edgecolor="white", facecolor="none", linewidth=1.5)
if gridlines:
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1.5, color="gray", alpha=0.6, linestyle="--")
if colorbar:
plt.colorbar(im)
plt.colorbar(im, extend="both")
plt.title(title, y=1.05)

return im

def plot_data(data,
fig=None,
cmap="RdBu",
title=None,
colorbar=False,
coastlines=False,
central_longitude=0,
lon=None,
lat=None,
**kwargs):

def plot_data(data, fig=None, cmap="RdBu", title=None, colorbar=False, coastlines=False, gridlines=False, central_longitude=0, lon=None, lat=None, **kwargs):
if fig == None:
fig = plt.figure()

nlat = data.shape[-2]
nlon = data.shape[-1]
if lon is None:
lon = np.linspace(0, 2*np.pi, nlon)
lon = np.linspace(0, 2 * np.pi, nlon)
if lat is None:
lat = np.linspace(np.pi/2., -np.pi/2., nlat)
lat = np.linspace(np.pi / 2.0, -np.pi / 2.0, nlat)
Lon, Lat = np.meshgrid(lon, lat)

proj = ccrs.PlateCarree(central_longitude=central_longitude)
# proj = ccrs.Mollweide(central_longitude=central_longitude)

ax = fig.add_subplot(projection=proj)
Lon = Lon*180/np.pi
Lat = Lat*180/np.pi
Lon = Lon * 180 / np.pi
Lat = Lat * 180 / np.pi

# contour data over the map.
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=False, **kwargs)
if coastlines:
ax.add_feature(cartopy.feature.COASTLINE, edgecolor='white', facecolor='none', linewidth=1.5)
ax.add_feature(cartopy.feature.COASTLINE, edgecolor="white", facecolor="none", linewidth=1.5)
if gridlines:
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1.5, color="gray", alpha=0.6, linestyle="--")
if colorbar:
plt.colorbar(im)
plt.colorbar(im, extend="both")
plt.title(title, y=1.05)

return im
return im
2 changes: 1 addition & 1 deletion torch_harmonics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

__version__ = "0.7.4"
__version__ = "0.7.5a"

from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
Expand Down
7 changes: 7 additions & 0 deletions torch_harmonics/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def _normalize_convolution_tensor_s2(

# buffer to store intermediate values
vnorm = torch.zeros(kernel_size, nlat_out)
support = torch.zeros(kernel_size, nlat_out)

# loop through dimensions to compute the norms
for ik in range(kernel_size):
Expand All @@ -100,6 +101,10 @@ def _normalize_convolution_tensor_s2(
# vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])

# compute the support
support[ik, ilat] = torch.sum(q[iidx])


# loop over values and renormalize
for ik in range(kernel_size):
for ilat in range(nlat_out):
Expand All @@ -110,6 +115,8 @@ def _normalize_convolution_tensor_s2(
val = vnorm[ik, ilat]
elif basis_norm_mode == "mean":
val = vnorm[ik, :].mean()
elif basis_norm_mode == "support":
val = support[ik, ilat]
elif basis_norm_mode == "none":
val = 1.0
else:
Expand Down
28 changes: 16 additions & 12 deletions torch_harmonics/filter_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,10 @@ def kernel_size(self):
def gaussian_window(self, r: torch.Tensor, width: float = 1.0):
return 1 / (2 * math.pi * width**2) * torch.exp(-0.5 * r**2 / (width**2))

def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 0.25):
def hann_window(self, r: torch.Tensor, width: float = 1.0):
return torch.cos(0.5 * torch.pi * r / width) ** 2

def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float, width: float = 1.0):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
Expand All @@ -252,19 +255,20 @@ def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: flo
# get relevant indices
iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))

# get x and y
x = r * torch.sin(phi) / r_cutoff
y = r * torch.cos(phi) / r_cutoff

harmonic = torch.where(nkernel % 2 == 1, torch.sin(torch.ceil(nkernel / 2) * math.pi * x / width), torch.cos(torch.ceil(nkernel / 2) * math.pi * x / width))
harmonic *= torch.where(mkernel % 2 == 1, torch.sin(torch.ceil(mkernel / 2) * math.pi * y / width), torch.cos(torch.ceil(mkernel / 2) * math.pi * y / width))
# get corresponding r, phi, x and y coordinates
r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff
phi = phi[iidx[:, 1], iidx[:, 2]]
x = r * torch.sin(phi)
y = r * torch.cos(phi)
n = nkernel[iidx[:, 0], 0, 0]
m = mkernel[iidx[:, 0], 0, 0]

# disk area
# disk_area = 2.0 * math.pi * (1.0 - math.cos(r_cutoff))
disk_area = 1.0
harmonic = torch.where(n % 2 == 1, torch.sin(torch.ceil(n / 2) * math.pi * x / width), torch.cos(torch.ceil(n / 2) * math.pi * x / width))
harmonic *= torch.where(m % 2 == 1, torch.sin(torch.ceil(m / 2) * math.pi * y / width), torch.cos(torch.ceil(m / 2) * math.pi * y / width))

# computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
vals = self.gaussian_window(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area
# computes the envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
# vals = self.gaussian_window(r, width=width) * harmonic
vals = self.hann_window(r, width=width) * harmonic

return iidx, vals

Expand Down

0 comments on commit 6663841

Please sign in to comment.