diff --git a/mlrun/db/filedb.py b/mlrun/db/filedb.py index 17b21b7828e..1716c9d724b 100644 --- a/mlrun/db/filedb.py +++ b/mlrun/db/filedb.py @@ -111,7 +111,6 @@ def store_run(self, struct, uid, project="", iter=0): def update_run(self, updates: dict, uid, project="", iter=0): run = self.read_run(uid, project, iter=iter) - # TODO: Should we raise if run not found? if run and updates: for key, val in updates.items(): update_in(run, key, val) diff --git a/mlrun/execution.py b/mlrun/execution.py index c94dc3e2b80..aa5381afbe2 100644 --- a/mlrun/execution.py +++ b/mlrun/execution.py @@ -19,6 +19,7 @@ from typing import List, Union import numpy as np +import yaml import mlrun from mlrun.artifacts import ModelArtifact @@ -93,6 +94,8 @@ def __init__(self, autocommit=False, tmp="", log_stream=None): self._outputs = [] self._results = {} + # tracks the execution state, completion of runs is not decided by the execution + # as there may be multiple executions for a single run (e.g mpi) self._state = "created" self._error = None self._commit = "" @@ -113,7 +116,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): if exc_value: self.set_state(error=exc_value, commit=False) - self.commit() + self.commit(completed=True) def get_child_context(self, with_parent_params=False, **params): """get child context (iteration) @@ -259,7 +262,7 @@ def from_dict( host=None, log_stream=None, is_api=False, - update_db=True, + store_run=True, ): """create execution context from dict""" @@ -314,8 +317,8 @@ def from_dict( if start: self._start_time = start self._state = "running" - if update_db: - self._update_db(commit=True) + if store_run: + self.store_run() return self @property @@ -330,6 +333,11 @@ def tag(self): """run tag (uid or workflow id if exists)""" return self._labels.get("workflow") or self._uid + @property + def state(self): + """execution state""" + return self._state + @property def iteration(self): """child iteration index, for hyper parameters""" @@ -445,7 +453,7 @@ def get_param(self, key: str, default=None): if key not in self._parameters: self._parameters[key] = default if default: - self._update_db() + self._update_run() return default return self._parameters[key] @@ -520,7 +528,7 @@ def log_result(self, key: str, value, commit=False): :param commit: commit (write to DB now vs wait for the end of the run) """ self._results[str(key)] = _cast_result(value) - self._update_db(commit=commit) + self._update_run(commit=commit) def log_results(self, results: dict, commit=False): """log a set of scalar result values @@ -539,7 +547,7 @@ def log_results(self, results: dict, commit=False): for p in results.keys(): self._results[str(p)] = _cast_result(results[p]) - self._update_db(commit=commit) + self._update_run(commit=commit) def log_iteration_results(self, best, summary: list, task: dict, commit=False): """Reserved for internal use""" @@ -566,7 +574,7 @@ def log_iteration_results(self, best, summary: list, task: dict, commit=False): if summary is not None: self._iteration_results = summary if commit: - self._update_db(commit=True) + self._update_run(commit=True) def log_metric(self, key: str, value, timestamp=None, labels=None): """TBD, log a real-time time-series metric""" @@ -648,7 +656,7 @@ def log_artifact( format=format, **kwargs, ) - self._update_db() + self._update_run() return item def log_dataset( @@ -727,7 +735,7 @@ def log_dataset( db_key=db_key, labels=labels, ) - self._update_db() + self._update_run() return item def log_model( @@ -829,7 +837,7 @@ def log_model( db_key=db_key, labels=labels, ) - self._update_db() + self._update_run() return item def get_cached_artifact(self, key): @@ -840,13 +848,16 @@ def update_artifact(self, artifact_object): """update an artifact object in the cache and the DB""" self._artifacts_manager.update_artifact(self, artifact_object) - def commit(self, message: str = "", completed=True): + def commit(self, message: str = "", completed=False): """save run state and optionally add a commit message :param message: commit message to save in the run :param completed: mark run as completed """ - completed = completed and self._state == "running" + # changing state to completed is allowed only when the execution is in running state + if self._state != "running": + completed = False + if message: self._annotations["message"] = message if completed: @@ -855,22 +866,28 @@ def commit(self, message: str = "", completed=True): if self._parent: self._parent.update_child_iterations() self._parent._last_update = now_date() - self._parent._update_db(commit=True, message=message) + self._parent._update_run(commit=True, message=message) if self._children: self.update_child_iterations(commit_children=True, completed=completed) self._last_update = now_date() - self._update_db(commit=True, message=message) + self._update_run(commit=True, message=message) if completed and not self.iteration: mlrun.runtimes.utils.global_context.set(None) - def set_state(self, state: str = None, error: str = None, commit=True): - """modify and store the run state or mark an error + def set_state(self, execution_state: str = None, error: str = None, commit=True): + """ + Modify and store the execution state or mark an error and update the run state accordingly. + This method allows to set the run state to 'completed' in the DB which is discouraged. + Completion of runs should be decided externally to the execution context. - :param state: set run state - :param error: error message (if exist will set the state to error) - :param commit: will immediately update the state in the DB + :param execution_state: set execution state + :param error: error message (if exist will set the state to error) + :param commit: will immediately update the state in the DB """ + # TODO: The execution context should not set the run state to completed. + # Create a separate state for the execution in the run object. + updates = {"status.last_update": now_date().isoformat()} if error: @@ -878,9 +895,13 @@ def set_state(self, state: str = None, error: str = None, commit=True): self._error = str(error) updates["status.state"] = "error" updates["status.error"] = error - elif state and state != self._state and self._state != "error": - self._state = state - updates["status.state"] = state + elif ( + execution_state + and execution_state != self._state + and self._state != "error" + ): + self._state = execution_state + updates["status.state"] = execution_state self._last_update = now_date() if self._rundb and commit: @@ -900,9 +921,9 @@ def set_hostname(self, host: str): def to_dict(self): """convert the run context to a dictionary""" - def set_if_valid(struct, key, val): + def set_if_not_none(_struct, key, val): if val: - struct[key] = val + _struct[key] = val struct = { "kind": "run", @@ -924,26 +945,52 @@ def set_if_valid(struct, key, val): run_keys.inputs: {k: v.artifact_url for k, v in self._inputs.items()}, }, "status": { - "state": self._state, "results": self._results, "start_time": to_date_str(self._start_time), "last_update": to_date_str(self._last_update), }, } + # completion of runs is not decided by the execution as there may be + # multiple executions for a single run (e.g. mpi) + if self._state != "completed": + struct["status"]["state"] = self._state + if not self._iteration: struct["spec"]["hyperparams"] = self._hyperparams struct["spec"]["hyper_param_options"] = self._hyper_param_options.to_dict() - set_if_valid(struct["status"], "error", self._error) - set_if_valid(struct["status"], "commit", self._commit) + set_if_not_none(struct["status"], "error", self._error) + set_if_not_none(struct["status"], "commit", self._commit) + set_if_not_none(struct["status"], "iterations", self._iteration_results) - if self._iteration_results: - struct["status"]["iterations"] = self._iteration_results struct["status"][run_keys.artifacts] = self._artifacts_manager.artifact_list() self._data_stores.to_dict(struct["spec"]) return struct + def _get_updates(self): + def set_if_not_none(_struct, key, val): + if val: + _struct[key] = val + + struct = { + "status.results": self._results, + "status.start_time": to_date_str(self._start_time), + "status.last_update": to_date_str(self._last_update), + } + + # completion of runs is not decided by the execution as there may be + # multiple executions for a single run (e.g. mpi) + if self._state != "completed": + struct["status.state"] = self._state + + set_if_not_none(struct, "status.error", self._error) + set_if_not_none(struct, "status.commit", self._commit) + set_if_not_none(struct, "status.iterations", self._iteration_results) + + struct[f"status.{run_keys.artifacts}"] = self._artifacts_manager.artifact_list() + return struct + def to_yaml(self): """convert the run context to a yaml buffer""" return dict_to_yaml(self.to_dict()) @@ -952,21 +999,55 @@ def to_json(self): """convert the run context to a json buffer""" return dict_to_json(self.to_dict()) - def _update_db(self, commit=False, message=""): - self.last_update = now_date() - if self._tmpfile: - data = self.to_json() - with open(self._tmpfile, "w") as fp: - fp.write(data) - fp.close() + def store_run(self): + self._write_tmpfile() + if self._rundb: + self._rundb.store_run( + self.to_dict(), self._uid, self.project, iter=self._iteration + ) + def _update_run(self, commit=False, message=""): + """ + update the required fields in the run object (using mlrun.utils.helpers.update_in) + instead of overwriting existing + """ + self._merge_tmpfile() if commit or self._autocommit: self._commit = message if self._rundb: - self._rundb.store_run( - self.to_dict(), self._uid, self.project, iter=self._iteration + self._rundb.update_run( + self._get_updates(), self._uid, self.project, iter=self._iteration ) + def _merge_tmpfile(self): + if not self._tmpfile: + return + + loaded_run = self._read_tmpfile() + dict_run = self.to_dict() + if loaded_run: + for key, val in dict_run.items(): + update_in(loaded_run, key, val) + else: + loaded_run = dict_run + + self._write_tmpfile(json=dict_to_json(loaded_run)) + + def _read_tmpfile(self): + if self._tmpfile: + with open(self._tmpfile) as fp: + return yaml.safe_load(fp) + + return None + + def _write_tmpfile(self, json=None): + self.last_update = now_date() + if self._tmpfile: + data = json or self.to_json() + with open(self._tmpfile, "w") as fp: + fp.write(data) + fp.close() + def _cast_result(value): if isinstance(value, (int, str, float)): diff --git a/mlrun/runtimes/base.py b/mlrun/runtimes/base.py index 1eee45ebcce..635698e1ed8 100644 --- a/mlrun/runtimes/base.py +++ b/mlrun/runtimes/base.py @@ -427,7 +427,7 @@ def run( db, autocommit=False, is_api=self._is_api_server, - update_db=False, + store_run=False, ) self._verify_run_params(run.spec.parameters) @@ -442,8 +442,8 @@ def run( for task in tasks: self._verify_run_params(task.spec.parameters) - # post verifications, update execution in db and run pre run hooks - execution.commit(completed=False) + # post verifications, store execution in db and run pre run hooks + execution.store_run() self._pre_run(run, execution) # hook for runtime specific prep last_err = None @@ -459,6 +459,7 @@ def run( results = runner(task_generator, execution, run) results_to_iter(results, run, execution) result = execution.to_dict() + result = self._update_run_state(result, task=run) else: # single run @@ -718,11 +719,11 @@ def _enrich_run( ) return runspec - def _submit_job(self, runspec, schedule, db, watch): + def _submit_job(self, run: RunObject, schedule, db, watch): if self._secrets: - runspec.spec.secret_sources = self._secrets.to_serial() + run.spec.secret_sources = self._secrets.to_serial() try: - resp = db.submit_job(runspec, schedule=schedule) + resp = db.submit_job(run, schedule=schedule) if schedule: logger.info(f"task scheduled, {resp}") return @@ -737,8 +738,8 @@ def _submit_job(self, runspec, schedule, db, watch): # if we got a schedule no reason to do post_run stuff (it purposed to update the run status with error, # but there's no run in case of schedule) if not schedule: - result = self._update_run_state(task=runspec, err=err_to_str(err)) - return self._wrap_run_result(result, runspec, schedule=schedule, err=err) + result = self._update_run_state(task=run, err=err_to_str(err)) + return self._wrap_run_result(result, run, schedule=schedule, err=err) if resp: txt = get_in(resp, "status.status_text") @@ -761,19 +762,19 @@ def _submit_job(self, runspec, schedule, db, watch): config.httpdb.logs.pipelines.pull_state.pull_logs_interval ) - runspec.wait_for_completion( + run.wait_for_completion( show_logs=True, sleep=state_interval, logs_interval=logs_interval, raise_on_failure=False, ) - resp = self._get_db_run(runspec) + resp = self._get_db_run(run) elif watch or self.kfp: - runspec.logs(True, self._get_db()) - resp = self._get_db_run(runspec) + run.logs(True, self._get_db()) + resp = self._get_db_run(run) - return self._wrap_run_result(resp, runspec, schedule=schedule) + return self._wrap_run_result(resp, run, schedule=schedule) @staticmethod def _handle_submit_job_http_error(error: requests.HTTPError): @@ -941,20 +942,32 @@ def _update_run_state( updates = None last_state = get_in(resp, "status.state", "") + kind = get_in(resp, "metadata.labels.kind", "") if last_state == "error" or err: - updates = {"status.last_update": now_date().isoformat()} - updates["status.state"] = "error" + updates = { + "status.last_update": now_date().isoformat(), + "status.state": "error", + } update_in(resp, "status.state", "error") if err: update_in(resp, "status.error", err_to_str(err)) err = get_in(resp, "status.error") if err: updates["status.error"] = err_to_str(err) - elif not was_none and last_state != "completed": - updates = {"status.last_update": now_date().isoformat()} - updates["status.state"] = "completed" - update_in(resp, "status.state", "completed") + elif not was_none and last_state != "completed": + try: + runtime_handler = mlrun.runtimes.get_runtime_handler(kind) + updates = runtime_handler._get_run_completion_updates(resp) + except KeyError: + updates = BaseRuntimeHandler._get_run_completion_updates(resp) + + logger.debug( + "Run updates", + kind=kind, + last_state=last_state, + updates=updates, + ) if self._get_db() and updates: project = get_in(resp, "metadata.project") uid = get_in(resp, "metadata.uid") @@ -1715,6 +1728,19 @@ def _get_default_label_selector(self) -> str: return f"mlrun/class={class_values[0]}" return f"mlrun/class in ({', '.join(class_values)})" + @staticmethod + def _get_run_completion_updates(run: dict) -> dict: + """ + Get the required updates for the run object when it's completed and update the run object state + Override this if the run completion is not resolved by a single execution + """ + updates = { + "status.last_update": now_date().isoformat(), + "status.state": "completed", + } + update_in(run, "status.state", "completed") + return updates + @staticmethod def _get_crd_info() -> Tuple[str, str, str]: """ diff --git a/mlrun/runtimes/local.py b/mlrun/runtimes/local.py index f5c4d5dbfb2..476286f41f6 100644 --- a/mlrun/runtimes/local.py +++ b/mlrun/runtimes/local.py @@ -456,7 +456,9 @@ def exec_from_params(handler, runobj: RunObject, context: MLClientCtx, cwd=None) context.set_logger_stream(sys.stdout) if val: context.log_result("return", val) - context.commit() + + # completion will be ignored if error is set + context.commit(completed=True) logger.set_logger_level(old_level) return stdout.buf.getvalue(), err diff --git a/mlrun/runtimes/mpijob/abstract.py b/mlrun/runtimes/mpijob/abstract.py index 2f1aaa85e5f..ec54df540e4 100644 --- a/mlrun/runtimes/mpijob/abstract.py +++ b/mlrun/runtimes/mpijob/abstract.py @@ -24,7 +24,7 @@ from mlrun.model import RunObject from mlrun.runtimes.kubejob import KubejobRuntime from mlrun.runtimes.pod import KubeResourceSpec -from mlrun.runtimes.utils import AsyncLogWriter, RunError +from mlrun.runtimes.utils import RunError from mlrun.utils import get_in, logger @@ -174,30 +174,15 @@ def _run(self, runobj: RunObject, execution: MLClientCtx): launcher, _ = self._get_launcher(meta.name, meta.namespace) execution.set_hostname(launcher) execution.set_state("running" if state == "active" else state) - if self.kfp: - writer = AsyncLogWriter(self._db_conn, runobj) - status = self._get_k8s().watch( - launcher, meta.namespace, writer=writer - ) - logger.info(f"MpiJob {meta.name} finished with state {status}") - if status == "succeeded": - execution.set_state("completed") - else: - execution.set_state( - "error", - f"MpiJob {meta.name} finished with state {status}", - ) - else: - txt = f"MpiJob {meta.name} launcher pod {launcher} state {state}" - logger.info(txt) - runobj.status.status_text = txt + txt = f"MpiJob {meta.name} launcher pod {launcher} state {state}" + logger.info(txt) + runobj.status.status_text = txt + else: pods_phases = self.get_pods(meta.name, meta.namespace) txt = f"MpiJob status unknown or failed, check pods: {pods_phases}" logger.warning(txt) runobj.status.status_text = txt - if self.kfp: - execution.set_state("error", txt) return None @@ -288,8 +273,6 @@ def _get_launcher(self, name, namespace=None): if not pods: logger.error("no pod matches that job name") return - # TODO: Why was this here? - # k8s = self._get_k8s() return list(pods.items())[0] def with_tracing( diff --git a/mlrun/runtimes/mpijob/v1.py b/mlrun/runtimes/mpijob/v1.py index 06e3762823e..7bcc643a37c 100644 --- a/mlrun/runtimes/mpijob/v1.py +++ b/mlrun/runtimes/mpijob/v1.py @@ -361,6 +361,15 @@ def _get_object_label_selector(object_id: str) -> str: def _get_possible_mlrun_class_label_values() -> typing.List[str]: return ["mpijob"] + @staticmethod + def _get_run_completion_updates(run: dict) -> dict: + + # TODO: add a 'workers' section in run objects state, each worker will update its state while + # the run state will be resolved by the server. + # update the run object state if empty so that it won't default to 'created' state + update_in(run, "status.state", "running", append=False, replace=False) + return {} + @staticmethod def _get_crd_info() -> typing.Tuple[str, str, str]: return ( diff --git a/mlrun/runtimes/mpijob/v1alpha1.py b/mlrun/runtimes/mpijob/v1alpha1.py index 1db61ac5e54..c1ba5902bb3 100644 --- a/mlrun/runtimes/mpijob/v1alpha1.py +++ b/mlrun/runtimes/mpijob/v1alpha1.py @@ -194,6 +194,15 @@ def _get_object_label_selector(object_id: str) -> str: def _get_possible_mlrun_class_label_values() -> typing.List[str]: return ["mpijob"] + @staticmethod + def _get_run_completion_updates(run: dict) -> dict: + + # TODO: add a 'workers' section in run objects state, each worker will update its state while + # the run state will be resolved by the server. + # update the run object state if empty so that it won't default to 'created' state + update_in(run, "status.state", "running", append=False, replace=False) + return {} + @staticmethod def _get_crd_info() -> typing.Tuple[str, str, str]: return ( diff --git a/tests/api/runtimes/base.py b/tests/api/runtimes/base.py index 165148bb990..41ff684f25a 100644 --- a/tests/api/runtimes/base.py +++ b/tests/api/runtimes/base.py @@ -340,6 +340,9 @@ def _mock_get_logger_pods(self): get_k8s().v1api.list_namespaced_pod = unittest.mock.Mock( return_value=client.V1PodList(items=[]) ) + get_k8s().v1api.read_namespaced_pod_log = unittest.mock.Mock( + return_value="Mocked pod logs" + ) def _mock_create_namespaced_custom_object(self): def _generate_custom_object( @@ -388,6 +391,7 @@ def execute_function(self, runtime, **kwargs): def _reset_mocks(self): get_k8s().v1api.create_namespaced_pod.reset_mock() get_k8s().v1api.list_namespaced_pod.reset_mock() + get_k8s().v1api.read_namespaced_pod_log.reset_mock() def _reset_custom_object_mocks(self): mlrun.api.utils.singletons.k8s.get_k8s().crdapi.create_namespaced_custom_object.reset_mock() diff --git a/tests/api/runtimes/test_mpijob.py b/tests/api/runtimes/test_mpijob.py new file mode 100644 index 00000000000..e220e70033b --- /dev/null +++ b/tests/api/runtimes/test_mpijob.py @@ -0,0 +1,110 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import typing +import unittest.mock + +from kubernetes import client as k8s_client + +import mlrun.runtimes.pod +from mlrun import code_to_function, mlconf +from mlrun.api.utils.singletons.k8s import get_k8s +from mlrun.runtimes.constants import MPIJobCRDVersions +from tests.api.runtimes.base import TestRuntimeBase + + +class TestMpiV1Runtime(TestRuntimeBase): + def custom_setup(self): + self.runtime_kind = "mpijob" + self.code_handler = "test_func" + self.name = "test-mpi-v1" + mlconf.mpijob_crd_version = MPIJobCRDVersions.v1 + + def test_run_v1_sanity(self): + self._mock_list_pods() + self._mock_create_namespaced_custom_object() + self._mock_get_namespaced_custom_object() + mpijob_function = self._generate_runtime(self.runtime_kind) + mpijob_function.deploy() + run = mpijob_function.run( + artifact_path="v3io:///mypath", + watch=False, + ) + + assert run.status.state == "running" + + def _mock_get_namespaced_custom_object(self, workers=1): + get_k8s().crdapi.get_namespaced_custom_object = unittest.mock.Mock( + return_value={ + "status": { + "replicaStatuses": { + "Launcher": { + "active": 1, + }, + "Worker": { + "active": workers, + }, + } + }, + } + ) + + def _mock_list_pods(self, workers=1, pods=None, phase="Running"): + if pods is None: + pods = [self._get_worker_pod(phase=phase)] * workers + pods += [self._get_launcher_pod(phase=phase)] + get_k8s().list_pods = unittest.mock.Mock(return_value=pods) + + def _get_worker_pod(self, phase="Running"): + return k8s_client.V1Pod( + metadata=k8s_client.V1ObjectMeta( + labels={ + "kind": "mpijob", + "owner": "tester", + "v3io_user": "tester", + "mpijob": "v1/mpi-job-role=worker", + }, + name=self.name, + ), + status=k8s_client.V1PodStatus(phase=phase), + ) + + def _get_launcher_pod(self, phase="Running"): + return k8s_client.V1Pod( + metadata=k8s_client.V1ObjectMeta( + labels={ + "kind": "mpijob", + "owner": "tester", + "v3io_user": "tester", + "mpijob": "v1/mpi-job-role=launcher", + }, + name=self.name, + ), + status=k8s_client.V1PodStatus(phase=phase), + ) + + def _generate_runtime( + self, kind=None, labels=None + ) -> typing.Union[mlrun.runtimes.MpiRuntimeV1, mlrun.runtimes.MpiRuntimeV1Alpha1]: + runtime = code_to_function( + name=self.name, + project=self.project, + filename=self.code_filename, + handler=self.code_handler, + kind=kind or self.runtime_kind, + image=self.image_name, + description="test mpijob", + labels=labels, + ) + return runtime diff --git a/tests/common_fixtures.py b/tests/common_fixtures.py index 64cffd6a74f..6a6993d2451 100644 --- a/tests/common_fixtures.py +++ b/tests/common_fixtures.py @@ -45,6 +45,7 @@ from mlrun.runtimes import BaseRuntime from mlrun.runtimes.function import NuclioStatus from mlrun.runtimes.utils import global_context +from mlrun.utils import update_in from tests.conftest import logs_path, results, root_path, rundb_path session_maker: Callable @@ -280,6 +281,17 @@ def get_builder_status( ): return "ready", last_log_timestamp + def update_run(self, updates: dict, uid, project="", iter=0): + state = self._function.get("state", {}) + update_in(state, "status.state", updates) + update_in(state, "status.results", updates) + update_in(state, "status.start_time", updates) + update_in(state, "status.last_update", updates) + update_in(state, "status.error", updates) + update_in(state, "status.commit", updates) + update_in(state, "status.iterations", updates) + self._function["state"] = state + def assert_no_mount_or_creds_configured(self): env_list = self._function["spec"]["env"] env_params = [item["name"] for item in env_list] diff --git a/tests/run/test_run.py b/tests/run/test_run.py index 320b873f932..3033f752043 100644 --- a/tests/run/test_run.py +++ b/tests/run/test_run.py @@ -162,6 +162,10 @@ def test_local_context(): project_name = "xtst" mlrun.mlconf.artifact_path = out_path context = mlrun.get_or_create_ctx("xx", project=project_name, upload_artifacts=True) + db = mlrun.get_run_db() + run = db.read_run(context._uid, project=project_name) + assert run["status"]["state"] == "running", "run status not updated in db" + with context: context.log_artifact("xx", body="123", local_path="a.txt") context.log_model("mdl", body="456", model_file="mdl.pkl", artifact_path="+/mm") @@ -172,9 +176,8 @@ def test_local_context(): assert context._state == "completed", "task did not complete" - db = mlrun.get_run_db() run = db.read_run(context._uid, project=project_name) - assert run["status"]["state"] == "completed", "run status not updated in db" + assert run["status"]["state"] == "running", "run status was updated in db" assert ( run["status"]["artifacts"][0]["metadata"]["key"] == "xx" ), "artifact not updated in db" diff --git a/tests/system/runtimes/assets/mpijob_function.py b/tests/system/runtimes/assets/mpijob_function.py new file mode 100644 index 00000000000..22b9663d9ba --- /dev/null +++ b/tests/system/runtimes/assets/mpijob_function.py @@ -0,0 +1,35 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import time + +from mpi4py import MPI + +import mlrun + + +@mlrun.handler(outputs=["time", "result"]) +def handler(context: mlrun.MLClientCtx): + # Start the timer: + run_time = time.time() + + # Get MPI rank: + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + + # Log the values (only from root rank (#0) in mpijob): + if rank == 0: + time.sleep(1) + run_time = time.time() - run_time + return run_time, 1000 diff --git a/tests/system/runtimes/test_mpijob.py b/tests/system/runtimes/test_mpijob.py new file mode 100644 index 00000000000..9d33dfe8d5d --- /dev/null +++ b/tests/system/runtimes/test_mpijob.py @@ -0,0 +1,45 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import mlrun +import tests.system.base +from mlrun.runtimes.constants import RunStates + + +@tests.system.base.TestMLRunSystem.skip_test_if_env_not_configured +class TestMpiJobRuntime(tests.system.base.TestMLRunSystem): + project_name = "does-not-exist-mpijob" + + def test_run_state_completion(self): + code_path = str(self.assets_path / "mpijob_function.py") + + # Create the open mpi function: + mpijob_function = mlrun.code_to_function( + name="mpijob_test", + kind="mpijob", + handler="handler", + project=self.project_name, + filename=code_path, + image="mlrun/ml-models", + requirements=["mpi4py"], + ) + mpijob_function.spec.replicas = 4 + + mpijob_run = mpijob_function.run(auto_build=True) + assert mpijob_run.status.state == RunStates.completed + + mpijob_time = mpijob_run.status.results["time"] + mpijob_result = mpijob_run.status.results["result"] + assert mpijob_time is not None + assert mpijob_result == 1000