Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add HRRR CONUS grid #19

Merged
merged 7 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion earth2grid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
# limitations under the License.
import torch

from earth2grid import base, healpix, latlon
from earth2grid import base, healpix, latlon, lcc
from earth2grid._regrid import BilinearInterpolator, Identity, KNNS2Interpolator, Regridder

__all__ = [
"base",
"healpix",
"latlon",
"lcc",
"get_regridder",
"BilinearInterpolator",
"KNNS2Interpolator",
Expand All @@ -36,6 +37,8 @@ def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module:
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, lcc.LambertConformalConicGrid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(dest, healpix.Grid):
Expand Down
201 changes: 201 additions & 0 deletions earth2grid/lcc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch

from earth2grid import base
from earth2grid._regrid import BilinearInterpolator

try:
import pyvista as pv
except ImportError:
pv = None


simonbyrne marked this conversation as resolved.
Show resolved Hide resolved

class LambertConformalConicProjection:
def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
"""

Args:
lat0: latitude of origin (degrees)
lon0: longitude of origin (degrees)
lat1: first standard parallel (degrees)
lat2: second standard parallel (degrees)
radius: radius of sphere (m)

"""

self.lon0 = lon0
self.lat0 = lat0
self.lat1 = lat1
self.lat2 = lat2
self.radius = radius


c1 = np.cos(np.deg2rad(lat1))
c2 = np.cos(np.deg2rad(lat2))
t1 = np.tan(np.pi / 4 + np.deg2rad(lat1) / 2)
t2 = np.tan(np.pi / 4 + np.deg2rad(lat2) / 2)

if np.abs(lat1 - lat2) < 1e-8:
self.n = np.sin(np.deg2rad(lat1))
else:
self.n = np.log(c1/c2) / np.log(t2/t1)

self.RF = radius * c1 * np.power(t1, self.n) / self.n
self.rho0 = self._rho(lat0)

def _rho(self, lat):
return self.RF / np.power(np.tan(np.pi / 4 + np.deg2rad(lat) / 2), self.n)

def _theta(self, lon):
"""
Angle of deviation (in radians) of the projected grid from the regular grid,
for a given longitude (in degrees).

To convert to U and V on the projected grid to easterly / northerly components:
UN = cos(theta) * U + sin(theta) * V
VN = - sin(theta) * U + cos(theta) * V
"""
# center about reference longitude
delta_lon = lon - self.lon0
delta_lon = delta_lon - np.round(delta_lon/360) * 360 # convert to [-180, 180]
return self.n * np.deg2rad(delta_lon)


def proj(self, lat, lon):
simonbyrne marked this conversation as resolved.
Show resolved Hide resolved
"""
Compute the projected x,y from lat,lon.
"""
rho = self._rho(lat)
theta = self._theta(lon)

x = rho * np.sin(theta)
y = self.rho0 - rho * np.cos(theta)
return x, y

def inv(self, x, y):
simonbyrne marked this conversation as resolved.
Show resolved Hide resolved
"""
Compute the lat,lon from the projected x,y.
"""
rho = np.hypot(x, self.rho0 - y)
theta = np.arctan2(x, self.rho0 - y)

lat = np.rad2deg(2 * np.arctan(np.power(self.RF/rho, 1/self.n))) - 90
lon = self.lon0 + np.rad2deg(theta / self.n)
return lat, lon

# Projection used by HRRR CONUS (Continental US) data
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(
lon0 = -97.5,
lat0 = 38.5,
lat1 = 38.5,
lat2 = 38.5,
simonbyrne marked this conversation as resolved.
Show resolved Hide resolved
radius = 6371229.0
)


class LambertConformalConicGrid(base.Grid):
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
def __init__(self, projection: LambertConformalConicProjection, x, y):
"""
Args:
projection: LambertConformalConicProjection object
x: range of x values
y: range of y values

"""
self.projection = projection

self.x = np.array(x)
self.y = np.array(y)

@property
def lat_lon(self):
mesh_x, mesh_y = np.meshgrid(self.x, self.y)
return self.projection.inv(mesh_x, mesh_y)

@property
def lat(self):
return self.lat_lon[0]

@property
def lon(self):
return self.lat_lon[1]

@property
def shape(self):
return (len(self.y), len(self.x))

def __getitem__(self, idxs):
yidxs, xidxs = idxs
return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs])

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""
return _RegridFromLCC(self, lat, lon)

def visualize(self, data):
raise NotImplementedError()

def to_pyvista(self):
if pv is None:
raise ImportError("Need to install pyvista")

lat, lon = self.lat_lon
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
z = np.sin(np.deg2rad(lat))
grid = pv.StructuredGrid(x, y, z)
return grid

def hrrr_conus_grid(ix0 = 0, iy0 = 0, nx = 1799, ny = 1059):
simonbyrne marked this conversation as resolved.
Show resolved Hide resolved
# coordinates of point in top-left corner
lat0 = 21.138123
lon0 = 237.280472
# grid length (m)
scale = 3000.0
# coordinates on projected space
x0, y0 = HRRR_CONUS_PROJECTION.proj(lat0, lon0)

x = [x0 + i * scale for i in range(ix0, ix0+nx)]
y = [y0 + i * scale for i in range(iy0, iy0+ny)]

return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)

# Grid used by HRRR CONUS (Continental US) data
HRRR_CONUS_GRID = hrrr_conus_grid()


class _RegridFromLCC(torch.nn.Module):
"""Regrid from LambertConformalConicGrid to unstructured grid with bilinear interpolation"""

def __init__(self, src: LambertConformalConicGrid, lat: np.ndarray, lon: np.ndarray):
super().__init__()

x, y = src.projection.proj(lat, lon)

self.shape = lat.shape
self._bilinear = BilinearInterpolator(
x_coords = torch.from_numpy(src.x),
y_coords = torch.from_numpy(src.y),
x_query = torch.from_numpy(x.ravel()),
y_query = torch.from_numpy(y.ravel()))

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self._bilinear(x)
return out.view(out.shape[:-1] + self.shape)
simonbyrne marked this conversation as resolved.
Show resolved Hide resolved

66 changes: 66 additions & 0 deletions tests/test_lcc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#%%
from earth2grid.lcc import HRRR_CONUS_GRID
import numpy as np
import torch
from pytest import approx
simonbyrne marked this conversation as resolved.
Show resolved Hide resolved

def test_grid_shape():
assert HRRR_CONUS_GRID.lat.shape == HRRR_CONUS_GRID.shape
assert HRRR_CONUS_GRID.lon.shape == HRRR_CONUS_GRID.shape

lats = np.array([
[21.138123, 21.801926, 22.393631, 22.911015],
[23.636763, 24.328228, 24.944668, 25.48374 ],
[26.155672, 26.875362, 27.517046, 28.078257],
[28.69017 , 29.438608, 30.106009, 30.68978 ]])

lons = np.array([
[-122.71953 , -120.03195 , -117.304596, -114.54146 ],
[-123.491356, -120.72898 , -117.92319 , -115.07828 ],
[-124.310524, -121.469505, -118.58098 , -115.649574],
[-125.181404, -122.25762 , -119.28173 , -116.25871 ]])


def test_grid_vals():
assert HRRR_CONUS_GRID.lat[0:400:100,0:400:100] == approx(lats)
assert HRRR_CONUS_GRID.lon[0:400:100,0:400:100] == approx(lons)

def test_grid_slice():
slice_grid = HRRR_CONUS_GRID[0:400:100,0:400:100]
assert slice_grid.lat == approx(lats)
assert slice_grid.lon == approx(lons)

def test_regrid_1d():
src = HRRR_CONUS_GRID
dest_lat = np.linspace(25.0, 33.0, 10)
dest_lon = np.linspace(-123, -98, 10)
regrid = src.get_bilinear_regridder_to(dest_lat, dest_lon)
src_lat = torch.broadcast_to(torch.tensor(src.lat), src.shape)
out_lat = regrid(src_lat)

assert torch.allclose(out_lat, torch.tensor(dest_lat))

def test_regrid_2d():
src = HRRR_CONUS_GRID
dest_lat, dest_lon = np.meshgrid(np.linspace(25.0, 33.0, 10), np.linspace(-123, -98, 10))
regrid = src.get_bilinear_regridder_to(dest_lat, dest_lon)
src_lat = torch.broadcast_to(torch.tensor(src.lat), src.shape)
out_lat = regrid(src_lat)

assert torch.allclose(out_lat, torch.tensor(dest_lat))

Loading