Skip to content

Commit

Permalink
Merge pull request #191 from Open-EO/add-s2geo-splitter
Browse files Browse the repository at this point in the history
implemented first version of the S2GeoSplitter
  • Loading branch information
VincentVerelst authored Jan 7, 2025
2 parents f1e79a5 + 0839394 commit 41e95cd
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- `split_job_s2sphere` function to split jobs into S2cells of the s2sphere package. More info: http://s2geometry.io/. This job splitter recursively splits cells until the number of points in each cell is less than a given threshold.

### Changed
- `ouput_path_generator` in `GFMapJobManager.on_job_done` now requires `asset_id` as a keyword argument
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"netCDF4",
"scipy",
"rasterio",
"s2sphere==0.2.*",
]

[project.urls]
Expand Down
94 changes: 94 additions & 0 deletions src/openeo_gfmap/manager/job_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import geopandas as gpd
import h3
import requests
import s2sphere

from openeo_gfmap.manager import _log

Expand Down Expand Up @@ -196,3 +197,96 @@ def split_job_hex(
split_datasets.append(sub_gdf.reset_index(drop=True))

return split_datasets


def split_job_s2sphere(
gdf: gpd.GeoDataFrame, max_points=500, start_level=8
) -> List[gpd.GeoDataFrame]:
"""
EXPERIMENTAL
Split a GeoDataFrame into multiple groups based on the S2geometry cell ID of each geometry.
S2geometry is a library that provides a way to index and query spatial data. This function splits
the GeoDataFrame into groups based on the S2 cell ID of each geometry, based on it's centroid.
If a cell contains more points than max_points, it will be recursively split into
smaller cells until each cell contains at most max_points points.
More information on S2geometry can be found at https://s2geometry.io/
An overview of the S2 cell hierarchy can be found at https://s2geometry.io/resources/s2cell_statistics.html
:param gdf: GeoDataFrame containing points to split
:param max_points: Maximum number of points per group
:param start_level: Starting S2 cell level
:return: List of GeoDataFrames containing the split groups
"""

if "geometry" not in gdf.columns:
raise ValueError("The GeoDataFrame must contain a 'geometry' column.")

if gdf.crs is None:
raise ValueError("The GeoDataFrame must contain a CRS")

# Store the original CRS of the GeoDataFrame and reproject to EPSG:3857
original_crs = gdf.crs
gdf = gdf.to_crs(epsg=3857)

# Add a centroid column to the GeoDataFrame and convert it to EPSG:4326
gdf["centroid"] = gdf.geometry.centroid

# Reproject the GeoDataFrame to its orginial CRS
gdf = gdf.to_crs(original_crs)

# Set the GeoDataFrame's geometry to the centroid column and reproject to EPSG:4326
gdf = gdf.set_geometry("centroid")
gdf = gdf.to_crs(epsg=4326)

# Create a dictionary to store points by their S2 cell ID
cell_dict = {}

# Iterate over each point in the GeoDataFrame
for idx, row in gdf.iterrows():
# Get the S2 cell ID for the point at a given level
cell_id = _get_s2cell_id(row.centroid, start_level)

if cell_id not in cell_dict:
cell_dict[cell_id] = []

cell_dict[cell_id].append(row)

result_groups = []

# Function to recursively split cells if they contain more points than max_points
def _split_s2cell(cell_id, points, current_level=start_level):
if len(points) <= max_points:
if len(points) > 0:
points = gpd.GeoDataFrame(
points, crs=original_crs, geometry="geometry"
).drop(columns=["centroid"])
points["s2sphere_cell_id"] = cell_id
points["s2sphere_cell_level"] = current_level
result_groups.append(gpd.GeoDataFrame(points))
else:
children = s2sphere.CellId(cell_id).children()
child_cells = {child.id(): [] for child in children}

for point in points:
child_cell_id = _get_s2cell_id(point.centroid, current_level + 1)
child_cells[child_cell_id].append(point)

for child_cell_id, child_points in child_cells.items():
_split_s2cell(child_cell_id, child_points, current_level + 1)

# Split cells that contain more points than max_points
for cell_id, points in cell_dict.items():
_split_s2cell(cell_id, points)

return result_groups


def _get_s2cell_id(point, level):
lat, lon = point.y, point.x
cell_id = s2sphere.CellId.from_lat_lng(
s2sphere.LatLng.from_degrees(lat, lon)
).parent(level)
return cell_id.id()
57 changes: 56 additions & 1 deletion tests/tests_unit/manager/test_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import geopandas as gpd
from shapely.geometry import Point, Polygon

from openeo_gfmap.manager.job_splitters import split_job_hex, split_job_s2grid
from openeo_gfmap.manager.job_splitters import (
split_job_hex,
split_job_s2grid,
split_job_s2sphere,
)


def test_split_job_s2grid():
Expand Down Expand Up @@ -91,3 +95,54 @@ def test_split_job_hex():
assert (
len(result[0]) == 3
), "The number of geometries in the first split should be 3."


def test_split_job_s2sphere():
# Create a mock GeoDataFrame with points
# The points are located in two different S2 tiles
data = {
"id": [1, 2, 3, 4, 5, 6, 7],
"geometry": [
Point(60.02, 4.57),
Point(58.34, 5.06),
Point(59.92, 3.37),
Point(59.93, 3.37),
Point(58.85, 4.90),
Point(58.77, 4.87),
Polygon(
[
(58.78, 4.88),
(58.78, 4.86),
(58.76, 4.86),
(58.76, 4.88),
(58.78, 4.88),
]
),
],
}
polygons = gpd.GeoDataFrame(data, crs="EPSG:4326")

# Define expected number of split groups
max_points = 3

# Call the function
result = split_job_s2sphere(polygons, max_points, start_level=8)

assert (
len(result) == 4
), "The number of GeoDataFrames returned should match the number of splits needed."

# Check if the geometries are preserved
for gdf in result:
assert (
"geometry" in gdf.columns
), "Each GeoDataFrame should have a geometry column."
assert gdf.crs == 4326, "The original CRS should be preserved."
for _, geom in gdf.iterrows():
geom_type = geom.geometry.geom_type
original_type = polygons[polygons.id == geom.id].geometry.geom_type.values[
0
]
assert (
geom_type == original_type
), "Original geometries should be preserved."

0 comments on commit 41e95cd

Please sign in to comment.