diff --git a/mlrun/runtimes/function.py b/mlrun/runtimes/function.py index 112a42b7d8..22d60add43 100644 --- a/mlrun/runtimes/function.py +++ b/mlrun/runtimes/function.py @@ -528,6 +528,7 @@ def deploy( logger.info("Starting remote function deploy") data = db.remote_builder(self, False, builder_env=builder_env) self.status = data["data"].get("status") + self._update_credentials_from_remote_build(data["data"]) self._wait_for_function_deployment(db, verbose=verbose) # NOTE: on older mlrun versions & nuclio versions, function are exposed via NodePort @@ -991,6 +992,42 @@ def _resolve_invocation_url(self, path, force_external_address): else: return f"http://{self.status.address}/{path}" + def _update_credentials_from_remote_build(self, remote_data): + self.metadata.credentials = remote_data.get("metadata", {}).get( + "credentials", {} + ) + + credentials_env_var_names = ["V3IO_ACCESS_KEY", "MLRUN_AUTH_SESSION"] + new_env = [] + + # the env vars in the local spec and remote spec are in the format of a list of dicts + # e.g.: + # env = [ + # { + # "name": "V3IO_ACCESS_KEY", + # "value": "some-value" + # }, + # ... + # ] + # remove existing credentials env vars + for env in self.spec.env: + if isinstance(env, dict): + env_name = env["name"] + elif isinstance(env, client.V1EnvVar): + env_name = env.name + else: + continue + + if env_name not in credentials_env_var_names: + new_env.append(env) + + # add credentials env vars from remote build + for remote_env in remote_data.get("spec", {}).get("env", []): + if remote_env.get("name") in credentials_env_var_names: + new_env.append(remote_env) + + self.spec.env = new_env + def parse_logs(logs): logs = json.loads(logs) diff --git a/tests/common_fixtures.py b/tests/common_fixtures.py index 41c9e01ce8..de76fcf520 100644 --- a/tests/common_fixtures.py +++ b/tests/common_fixtures.py @@ -205,7 +205,13 @@ def remote_builder( state="ready", nuclio_name="test-nuclio-name", ) - return {"data": {"status": status.to_dict()}} + return { + "data": { + "status": status.to_dict(), + "metadata": self._function.get("metadata"), + "spec": self._function.get("spec"), + } + } def get_builder_status( self, diff --git a/tests/runtimes/test_function.py b/tests/runtimes/test_function.py index c99036ff24..59869ed5ed 100644 --- a/tests/runtimes/test_function.py +++ b/tests/runtimes/test_function.py @@ -1,6 +1,7 @@ import pathlib import sys +import pytest from deepdiff import DeepDiff import mlrun @@ -161,3 +162,28 @@ def test_resolve_git_reference_from_source(): ] for source, expected in cases: assert expected == _resolve_git_reference_from_source(source) + + +@pytest.mark.parametrize("function_kind", ["serving", "remote"]) +def test_update_credentials_from_remote_build(function_kind): + secret_name = "secret-name" + remote_data = { + "metadata": {"credentials": {"access_key": secret_name}}, + "spec": { + "env": [ + {"name": "V3IO_ACCESS_KEY", "value": secret_name}, + {"name": "MLRUN_AUTH_SESSION", "value": secret_name}, + ], + }, + } + + function = mlrun.new_function("tst", kind=function_kind) + function.metadata.credentials.access_key = "access_key" + function.spec.env = [ + {"name": "V3IO_ACCESS_KEY", "value": "access_key"}, + {"name": "MLRUN_AUTH_SESSION", "value": "access_key"}, + ] + function._update_credentials_from_remote_build(remote_data) + + assert function.metadata.credentials.access_key == secret_name + assert function.spec.env == remote_data["spec"]["env"]