diff --git a/CHANGELOG.md b/CHANGELOG.md index c72ac303c..8cd8e54c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,9 +13,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `Connection.web_editor()` to build link to the openEO backend in the openEO Web Editor - Add support for `log_level` in `create_job()` and `execute_job()` ([#704](https://github.com/Open-EO/openeo-python-client/issues/704)) - Add initial support for "geometry" dimension type in `CubeMetadata` ([#705](https://github.com/Open-EO/openeo-python-client/issues/705)) +- Add support for parameterized `bands` argument in `load_stac()` ### Changed +- Raise exception when providing empty bands array to `load_collection`/`load_stac` ([#424](https://github.com/Open-EO/openeo-python-client/issues/424), [Open-EO/openeo-processes#372](https://github.com/Open-EO/openeo-processes/issues/372)) + ### Removed ### Fixed diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 0a157cd45..5407c8839 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -1258,7 +1258,7 @@ def load_collection( collection_id: Union[str, Parameter], spatial_extent: Union[Dict[str, float], Parameter, None] = None, temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, - bands: Union[None, List[str], Parameter] = None, + bands: Union[Iterable[str], Parameter, str, None] = None, properties: Union[ None, Dict[str, Union[str, PGNode, Callable]], List[CollectionProperty], CollectionProperty ] = None, @@ -1348,7 +1348,7 @@ def load_stac( url: str, spatial_extent: Union[Dict[str, float], Parameter, None] = None, temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, - bands: Optional[List[str]] = None, + bands: Union[Iterable[str], Parameter, str, None] = None, properties: Optional[Dict[str, Union[str, PGNode, Callable]]] = None, ) -> DataCube: """ diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py index a13abc0a4..d050e5306 100644 --- a/openeo/rest/datacube.py +++ b/openeo/rest/datacube.py @@ -145,7 +145,7 @@ def load_collection( connection: Optional[Connection] = None, spatial_extent: Union[Dict[str, float], Parameter, None] = None, temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, - bands: Union[None, List[str], Parameter] = None, + bands: Union[Iterable[str], Parameter, str, None] = None, fetch_metadata: bool = True, properties: Union[ None, Dict[str, Union[str, PGNode, typing.Callable]], List[CollectionProperty], CollectionProperty @@ -198,10 +198,9 @@ def load_collection( metadata: Optional[CollectionMetadata] = ( connection.collection_metadata(collection_id) if connection and fetch_metadata else None ) - if bands: - if isinstance(bands, str): - bands = [bands] - elif isinstance(bands, Parameter): + if bands is not None: + bands = cls._get_bands(bands, process_id="load_collection") + if isinstance(bands, Parameter): metadata = None if metadata: bands = [b if isinstance(b, str) else metadata.band_dimension.band_name(b) for b in bands] @@ -272,7 +271,7 @@ def load_stac( url: str, spatial_extent: Union[Dict[str, float], Parameter, None] = None, temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, - bands: Optional[List[str]] = None, + bands: Union[Iterable[str], Parameter, str, None] = None, properties: Optional[Dict[str, Union[str, PGNode, Callable]]] = None, connection: Optional[Connection] = None, ) -> DataCube: @@ -379,7 +378,8 @@ def load_stac( arguments["spatial_extent"] = spatial_extent if temporal_extent: arguments["temporal_extent"] = DataCube._get_temporal_extent(extent=temporal_extent) - if bands: + bands = cls._get_bands(bands, process_id="load_stac") + if bands is not None: arguments["bands"] = bands if properties: arguments["properties"] = { @@ -388,7 +388,7 @@ def load_stac( graph = PGNode("load_stac", arguments=arguments) try: metadata = metadata_from_stac(url) - if bands: + if isinstance(bands, list): # TODO: also apply spatial/temporal filters to metadata? metadata = metadata.filter_bands(band_names=bands) except Exception: @@ -429,6 +429,24 @@ def convertor(d: Any) -> Any: get_temporal_extent(*args, start_date=start_date, end_date=end_date, extent=extent, convertor=convertor) ) + @staticmethod + def _get_bands( + bands: Union[Iterable[str], Parameter, str, None], process_id: str + ) -> Union[None, List[str], Parameter]: + """Normalize band array for processes like load_collection, load_stac""" + if bands is None: + pass + elif isinstance(bands, str): + bands = [bands] + elif isinstance(bands, Parameter): + pass + else: + # Coerce to list + bands = list(bands) + if len(bands) == 0: + raise OpenEoClientException(f"Bands array should not be empty (process {process_id!r})") + return bands + @openeo_process def filter_temporal( self, diff --git a/tests/rest/datacube/test_datacube100.py b/tests/rest/datacube/test_datacube100.py index 297fc7c1e..2e7406768 100644 --- a/tests/rest/datacube/test_datacube100.py +++ b/tests/rest/datacube/test_datacube100.py @@ -2308,6 +2308,34 @@ def test_load_collection_parameterized_bands(con100): } +@pytest.mark.parametrize( + "bands", + [ + ["B02", "B03"], + ("B02", "B03"), + iter(["B02", "B03"]), + ], +) +def test_load_collection_bands_iterable(con100, bands): + cube = con100.load_collection("S2", bands=bands) + assert get_download_graph(cube, drop_save_result=True) == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": { + "id": "S2", + "spatial_extent": None, + "temporal_extent": None, + "bands": ["B02", "B03"], + }, + }, + } + + +def test_load_collection_empty_bands_array(con100): + with pytest.raises(OpenEoClientException, match="Bands array should not be empty"): + _ = con100.load_collection("S2", bands=[]) + + @pytest.mark.parametrize( ["spatial_extent", "temporal_extent", "spatial_name", "temporal_name"], [ diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 3664edcc5..20c36987a 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -17,6 +17,7 @@ import openeo from openeo import BatchJob +from openeo.api.process import Parameter from openeo.capabilities import ApiVersionException from openeo.internal.graph_building import FlatGraphableMixin, PGNode from openeo.metadata import _PYSTAC_1_9_EXTENSION_INTERFACE, TemporalDimension @@ -2681,6 +2682,51 @@ def test_load_stac_band_filtering(self, con120, tmp_path): cube = con120.load_stac(str(stac_path), bands=["B03", "B02"]) assert cube.metadata.band_names == ["B03", "B02"] + @pytest.mark.parametrize( + "bands", + [ + ["B02", "B03"], + ("B02", "B03"), + iter(["B02", "B03"]), + ], + ) + def test_bands_iterable(self, con120, bands): + cube = con120.load_stac( + "https://provider.test/dataset", + bands=bands, + ) + assert cube.flat_graph() == { + "loadstac1": { + "process_id": "load_stac", + "arguments": { + "url": "https://provider.test/dataset", + "bands": ["B02", "B03"], + }, + "result": True, + } + } + + def test_bands_empty(self, con120): + with pytest.raises(OpenEoClientException, match="Bands array should not be empty"): + _ = con120.load_stac("https://provider.test/dataset", bands=[]) + + def test_bands_parameterized(self, con120): + bands = Parameter(name="my_bands", schema={"type": "array", "items": {"type": "string"}}) + cube = con120.load_stac( + "https://provider.test/dataset", + bands=bands, + ) + assert cube.flat_graph() == { + "loadstac1": { + "process_id": "load_stac", + "arguments": { + "url": "https://provider.test/dataset", + "bands": {"from_parameter": "my_bands"}, + }, + "result": True, + } + } + @pytest.mark.parametrize( "data",