Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resample cube update #300

Merged
merged 3 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 46 additions & 67 deletions openeo_processes_dask/process_implementations/cubes/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,30 @@ def resample_spatial(
):
"""Resamples the spatial dimensions (x,y) of the data cube to a specified resolution and/or warps the data cube to the target projection. At least resolution or projection must be specified."""

if data.openeo.y_dim is None or data.openeo.x_dim is None:
raise DimensionMissing(f"Spatial dimension missing for dataset: {data} ")

methods_list = [
"near",
"bilinear",
"cubic",
"cubicspline",
"lanczos",
"average",
"mode",
"max",
"min",
"med",
"q1",
"q3",
]

if method not in methods_list:
raise Exception(
f'Selected resampling method "{method}" is not available! Please select one of '
f"[{', '.join(methods_list)}]"
)

# Assert resampling method is correct.
if method == "near":
method = "nearest"
Expand All @@ -53,18 +77,9 @@ def resample_spatial(
f"[{', '.join(resample_methods_list)}]"
)

dims = list(data.dims)
dim_order = data.dims

known_dims = []
if len(data.openeo.band_dims) > 0:
known_dims.append(data.openeo.band_dims[0])
if len(data.openeo.temporal_dims) > 0:
known_dims.append(data.openeo.temporal_dims[0])
known_dims += [data.openeo.y_dim, data.openeo.x_dim]

other_dims = [dim for dim in dims if dim not in known_dims]

data_cp = data.transpose(*other_dims, *known_dims)
data_cp = data.transpose(..., data.openeo.y_dim, data.openeo.x_dim)

if projection is None:
projection = data_cp.rio.crs
Expand All @@ -89,6 +104,8 @@ def resample_spatial(
if reprojected.openeo.y_dim != data.openeo.y_dim:
reprojected = reprojected.rename({reprojected.openeo.y_dim: data.openeo.y_dim})

reprojected = reprojected.transpose(*dim_order)

reprojected.attrs["crs"] = data_cp.rio.crs

return reprojected
Expand All @@ -97,67 +114,29 @@ def resample_spatial(
def resample_cube_spatial(
data: RasterCube, target: RasterCube, method="near", options=None
) -> RasterCube:
methods_list = [
"near",
"bilinear",
"cubic",
"cubicspline",
"lanczos",
"average",
"mode",
"max",
"min",
"med",
"q1",
"q3",
]

if (
data.openeo.y_dim is None
or data.openeo.x_dim is None
or target.openeo.y_dim is None
or target.openeo.x_dim is None
):
if target.openeo.y_dim is None or target.openeo.x_dim is None:
raise DimensionMissing(
f"Spatial dimension missing from data or target. Available dimensions for data: {data.dims} for target: {target.dims}"
f"Spatial dimension missing for target dataset: {target} "
)

# ODC reproject requires y to be before x
required_dim_order = (..., data.openeo.y_dim, data.openeo.x_dim)

data_reordered = data.transpose(*required_dim_order, missing_dims="ignore")
target_reordered = target.transpose(*required_dim_order, missing_dims="ignore")

if method == "near":
method = "nearest"

elif method not in methods_list:
raise Exception(
f'Selected resampling method "{method}" is not available! Please select one of '
f"[{', '.join(methods_list)}]"
)

resampled_data = data_reordered.odc.reproject(
target_reordered.odc.geobox, resampling=method
target_resolution, target_crs = None, None
if hasattr(target, "rio"):
if hasattr(target.rio, "resolution"):
if type(target.rio.resolution()) in [tuple, list]:
target_resolution = target.rio.resolution()[0]
else:
target_resolution = target.rio.resolution()
if hasattr(target.rio, "crs"):
target_crs = target.rio.crs
if not target_crs:
raise OpenEOException(f"Projection not found in target dataset: {target} ")
if not target_resolution:
raise OpenEOException(f"Resolution not found in target dataset: {target} ")

resampled_data = resample_spatial(
data=data, projection=target_crs, resolution=target_resolution, method=method
)

resampled_data.rio.write_crs(target_reordered.rio.crs, inplace=True)

try:
# odc.reproject renames the coordinates according to the geobox, this undoes that.
resampled_data = resampled_data.rename(
{"longitude": data.openeo.x_dim, "latitude": data.openeo.y_dim}
)
except ValueError:
pass

# Order axes back to how they were before
resampled_data = resampled_data.transpose(*data.dims)

# Ensure that attrs except crs are copied over
for k, v in data.attrs.items():
if k.lower() != "crs":
resampled_data.attrs[k] = v
return resampled_data


Expand Down
42 changes: 42 additions & 0 deletions tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,48 @@ def test_resample_cube_spatial(
assert output_cube.odc.spatial_dims == ("y", "x")


@pytest.mark.parametrize(
"output_crs",
[
3587,
"32633",
"+proj=aeqd +lat_0=53 +lon_0=24 +x_0=5837287.81977 +y_0=2121415.69617 +datum=WGS84 +units=m +no_defs",
],
)
@pytest.mark.parametrize("output_res", [5, 30, 60])
@pytest.mark.parametrize("size", [(30, 30, 20, 4)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_resample_cube_spatial_small(
output_crs, output_res, temporal_interval, bounding_box, random_raster_data
):
"""Test to ensure resolution gets changed correctly."""
input_cube = create_fake_rastercube(
data=random_raster_data,
spatial_extent=bounding_box,
temporal_extent=temporal_interval,
bands=["B02", "B03", "B04", "B08"],
backend="dask",
)

resampled_cube = resample_spatial(
data=input_cube, projection=output_crs, resolution=output_res
)

output_cube = resample_cube_spatial(
data=input_cube, target=resampled_cube[10:60, 20:150, :, :], method="average"
)

general_output_checks(
input_cube=input_cube,
output_cube=output_cube,
expected_dims=input_cube.dims,
verify_attrs=False,
verify_crs=False,
)

assert list(output_cube.shape) == list(resampled_cube.shape)


@pytest.mark.parametrize("size", [(6, 5, 30, 4)])
@pytest.mark.parametrize("dtype", [np.float64])
@pytest.mark.parametrize(
Expand Down
Loading