Skip to content

Commit

Permalink
Disallow empty bands array in load_collection/load_stac
Browse files Browse the repository at this point in the history
and add support for parameterized bands in `load_stac`

refs: #424, Open-EO/openeo-processes#372
  • Loading branch information
soxofaan committed Jan 16, 2025
1 parent ddd2185 commit a907e38
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `Connection.web_editor()` to build link to the openEO backend in the openEO Web Editor
- Add support for `log_level` in `create_job()` and `execute_job()` ([#704](https://github.com/Open-EO/openeo-python-client/issues/704))
- Add initial support for "geometry" dimension type in `CubeMetadata` ([#705](https://github.com/Open-EO/openeo-python-client/issues/705))
- Add support for parameterized `bands` argument in `load_stac()`

### Changed

- Raise exception when providing empty bands array to `load_collection`/`load_stac` ([#424](https://github.com/Open-EO/openeo-python-client/issues/424), [Open-EO/openeo-processes#372](https://github.com/Open-EO/openeo-processes/issues/372))

### Removed

### Fixed
Expand Down
4 changes: 2 additions & 2 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,7 +1258,7 @@ def load_collection(
collection_id: Union[str, Parameter],
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
bands: Union[None, List[str], Parameter] = None,
bands: Union[Iterable[str], Parameter, str, None] = None,
properties: Union[
None, Dict[str, Union[str, PGNode, Callable]], List[CollectionProperty], CollectionProperty
] = None,
Expand Down Expand Up @@ -1348,7 +1348,7 @@ def load_stac(
url: str,
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
bands: Optional[List[str]] = None,
bands: Union[Iterable[str], Parameter, str, None] = None,
properties: Optional[Dict[str, Union[str, PGNode, Callable]]] = None,
) -> DataCube:
"""
Expand Down
34 changes: 26 additions & 8 deletions openeo/rest/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def load_collection(
connection: Optional[Connection] = None,
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
bands: Union[None, List[str], Parameter] = None,
bands: Union[Iterable[str], Parameter, str, None] = None,
fetch_metadata: bool = True,
properties: Union[
None, Dict[str, Union[str, PGNode, typing.Callable]], List[CollectionProperty], CollectionProperty
Expand Down Expand Up @@ -198,10 +198,9 @@ def load_collection(
metadata: Optional[CollectionMetadata] = (
connection.collection_metadata(collection_id) if connection and fetch_metadata else None
)
if bands:
if isinstance(bands, str):
bands = [bands]
elif isinstance(bands, Parameter):
if bands is not None:
bands = cls._get_bands(bands, process_id="load_collection")
if isinstance(bands, Parameter):
metadata = None
if metadata:
bands = [b if isinstance(b, str) else metadata.band_dimension.band_name(b) for b in bands]
Expand Down Expand Up @@ -272,7 +271,7 @@ def load_stac(
url: str,
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
bands: Optional[List[str]] = None,
bands: Union[Iterable[str], Parameter, str, None] = None,
properties: Optional[Dict[str, Union[str, PGNode, Callable]]] = None,
connection: Optional[Connection] = None,
) -> DataCube:
Expand Down Expand Up @@ -379,7 +378,8 @@ def load_stac(
arguments["spatial_extent"] = spatial_extent
if temporal_extent:
arguments["temporal_extent"] = DataCube._get_temporal_extent(extent=temporal_extent)
if bands:
bands = cls._get_bands(bands, process_id="load_stac")
if bands is not None:
arguments["bands"] = bands
if properties:
arguments["properties"] = {
Expand All @@ -388,7 +388,7 @@ def load_stac(
graph = PGNode("load_stac", arguments=arguments)
try:
metadata = metadata_from_stac(url)
if bands:
if isinstance(bands, list):
# TODO: also apply spatial/temporal filters to metadata?
metadata = metadata.filter_bands(band_names=bands)
except Exception:
Expand Down Expand Up @@ -429,6 +429,24 @@ def convertor(d: Any) -> Any:
get_temporal_extent(*args, start_date=start_date, end_date=end_date, extent=extent, convertor=convertor)
)

@staticmethod
def _get_bands(
bands: Union[Iterable[str], Parameter, str, None], process_id: str
) -> Union[None, List[str], Parameter]:
"""Normalize band array for processes like load_collection, load_stac"""
if bands is None:
pass
elif isinstance(bands, str):
bands = [bands]
elif isinstance(bands, Parameter):
pass
else:
# Coerce to list
bands = list(bands)
if len(bands) == 0:
raise OpenEoClientException(f"Bands array should not be empty (process {process_id!r})")
return bands

@openeo_process
def filter_temporal(
self,
Expand Down
28 changes: 28 additions & 0 deletions tests/rest/datacube/test_datacube100.py
Original file line number Diff line number Diff line change
Expand Up @@ -2308,6 +2308,34 @@ def test_load_collection_parameterized_bands(con100):
}


@pytest.mark.parametrize(
"bands",
[
["B02", "B03"],
("B02", "B03"),
iter(["B02", "B03"]),
],
)
def test_load_collection_bands_iterable(con100, bands):
cube = con100.load_collection("S2", bands=bands)
assert get_download_graph(cube, drop_save_result=True) == {
"loadcollection1": {
"process_id": "load_collection",
"arguments": {
"id": "S2",
"spatial_extent": None,
"temporal_extent": None,
"bands": ["B02", "B03"],
},
},
}


def test_load_collection_empty_bands_array(con100):
with pytest.raises(OpenEoClientException, match="Bands array should not be empty"):
_ = con100.load_collection("S2", bands=[])


@pytest.mark.parametrize(
["spatial_extent", "temporal_extent", "spatial_name", "temporal_name"],
[
Expand Down
46 changes: 46 additions & 0 deletions tests/rest/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import openeo
from openeo import BatchJob
from openeo.api.process import Parameter
from openeo.capabilities import ApiVersionException
from openeo.internal.graph_building import FlatGraphableMixin, PGNode
from openeo.metadata import _PYSTAC_1_9_EXTENSION_INTERFACE, TemporalDimension
Expand Down Expand Up @@ -2681,6 +2682,51 @@ def test_load_stac_band_filtering(self, con120, tmp_path):
cube = con120.load_stac(str(stac_path), bands=["B03", "B02"])
assert cube.metadata.band_names == ["B03", "B02"]

@pytest.mark.parametrize(
"bands",
[
["B02", "B03"],
("B02", "B03"),
iter(["B02", "B03"]),
],
)
def test_bands_iterable(self, con120, bands):
cube = con120.load_stac(
"https://provider.test/dataset",
bands=bands,
)
assert cube.flat_graph() == {
"loadstac1": {
"process_id": "load_stac",
"arguments": {
"url": "https://provider.test/dataset",
"bands": ["B02", "B03"],
},
"result": True,
}
}

def test_bands_empty(self, con120):
with pytest.raises(OpenEoClientException, match="Bands array should not be empty"):
_ = con120.load_stac("https://provider.test/dataset", bands=[])

def test_bands_parameterized(self, con120):
bands = Parameter(name="my_bands", schema={"type": "array", "items": {"type": "string"}})
cube = con120.load_stac(
"https://provider.test/dataset",
bands=bands,
)
assert cube.flat_graph() == {
"loadstac1": {
"process_id": "load_stac",
"arguments": {
"url": "https://provider.test/dataset",
"bands": {"from_parameter": "my_bands"},
},
"result": True,
}
}


@pytest.mark.parametrize(
"data",
Expand Down

0 comments on commit a907e38

Please sign in to comment.