Skip to content

Commit

Permalink
Relax runtime validation in UDF._guess_runtime
Browse files Browse the repository at this point in the history
and improve type annotations along the way
related to #470 and Open-EO/openeo-api#510
  • Loading branch information
soxofaan committed Sep 8, 2023
1 parent ba27981 commit 6cc64cb
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions openeo/rest/_datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(
def __repr__(self):
return f"<{type(self).__name__} runtime={self._runtime!r} code={str_truncate(self.code, width=200)!r}>"

def get_runtime(self, connection: Connection) -> str:
def get_runtime(self, connection: Optional[Connection] = None) -> str:
return self._runtime or self._guess_runtime(connection=connection)

@classmethod
Expand Down Expand Up @@ -221,7 +221,7 @@ def from_url(
code=code, runtime=runtime, version=version, context=context, _source=url
)

def _guess_runtime(self, connection: Connection) -> str:
def _guess_runtime(self, connection: Optional[Connection] = None) -> str:
"""Guess UDF runtime from UDF source (path) or source code."""
# First, guess UDF language
language = None
Expand All @@ -240,22 +240,22 @@ def _guess_runtime(self, connection: Connection) -> str:
# TODO: detection heuristics for R and other languages?
if not language:
raise OpenEoClientException("Failed to detect language of UDF code.")
# Find runtime for language
runtimes = {k.lower(): k for k in connection.list_udf_runtimes().keys()}
if language.lower() in runtimes:
return runtimes[language.lower()]
else:
raise OpenEoClientException(
f"Failed to match UDF language {language!r} with a runtime ({runtimes})"
)
runtime = language
if connection:
# Some additional best-effort validation/normalization of the runtime
# TODO: this just does some case-normalization, just drop that all together to eliminate
# the dependency on a connection object. See https://github.com/Open-EO/openeo-api/issues/510
runtimes = {k.lower(): k for k in connection.list_udf_runtimes().keys()}
runtime = runtimes.get(runtime.lower(), runtime)
return runtime

def _guess_runtime_from_suffix(self, suffix: str) -> Union[str]:
return {
".py": "Python",
".r": "R",
}.get(suffix.lower())

def get_run_udf_callback(self, connection: Connection, data_parameter: str = "data") -> PGNode:
def get_run_udf_callback(self, connection: Optional[Connection] = None, data_parameter: str = "data") -> PGNode:
"""
For internal use: construct `run_udf` node to be used as callback in `apply`, `reduce_dimension`, ...
"""
Expand All @@ -272,18 +272,20 @@ def get_run_udf_callback(self, connection: Connection, data_parameter: str = "da
def build_child_callback(
process: Union[str, PGNode, typing.Callable, UDF],
parent_parameters: List[str],
connection: Optional["openeo.Connection"] = None,
connection: Optional[Connection] = None,
) -> dict:
"""
Build a "callback" process: a user defined process that is used by another process (such
as `apply`, `apply_dimension`, `reduce`, ....)
:param process: process id string, PGNode or callable that uses the ProcessBuilder mechanism to build a process
:param parent_parameters: list of parameter names defined for child process
:param connection: optional connection object to improve runtime validation for UDFs
:return:
"""
# TODO: move this to more generic process graph building utility module
# TODO: autodetect the parameters defined by parent process?
# TODO: eliminate need for connection object (also see `UDF._guess_runtime`)
if isinstance(process, PGNode):
# Assume this is already a valid callback process
pg = process
Expand Down

0 comments on commit 6cc64cb

Please sign in to comment.