Skip to content

Commit

Permalink
[Project] Sync functions flag in load_project (mlrun#2871)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonmr authored Jan 16, 2023
1 parent 321c91a commit 2df68aa
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 15 deletions.
42 changes: 27 additions & 15 deletions mlrun/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def load_project(
clone: bool = False,
user_project: bool = False,
save: bool = True,
sync_functions: bool = False,
) -> "MlrunProject":
"""Load an MLRun project from git or tar or dir
Expand All @@ -218,20 +219,21 @@ def load_project(
project = load_project("./demo_proj", "git://github.com/mlrun/project-demo.git")
project.run("main", arguments={'data': data_url})
:param context: project local directory path
:param url: name (in DB) or git or tar.gz or .zip sources archive path e.g.:
git://github.com/mlrun/demo-xgb-project.git
http://mysite/archived-project.zip
<project-name>
The git project should include the project yaml file.
If the project yaml file is in a sub-directory, must specify the sub-directory.
:param name: project name
:param secrets: key:secret dict or SecretsStore used to download sources
:param init_git: if True, will git init the context dir
:param subpath: project subpath (within the archive)
:param clone: if True, always clone (delete any existing content)
:param user_project: add the current user name to the project name (for db:// prefixes)
:param save: whether to save the created project and artifact in the DB
:param context: project local directory path
:param url: name (in DB) or git or tar.gz or .zip sources archive path e.g.:
git://github.com/mlrun/demo-xgb-project.git
http://mysite/archived-project.zip
<project-name>
The git project should include the project yaml file.
If the project yaml file is in a sub-directory, must specify the sub-directory.
:param name: project name
:param secrets: key:secret dict or SecretsStore used to download sources
:param init_git: if True, will git init the context dir
:param subpath: project subpath (within the archive)
:param clone: if True, always clone (delete any existing content)
:param user_project: add the current user name to the project name (for db:// prefixes)
:param save: whether to save the created project and artifact in the DB
:param sync_functions: sync the project's functions into the project object (will be saved to the DB if save=True)
:returns: project object
"""
Expand Down Expand Up @@ -280,9 +282,15 @@ def load_project(
project.spec.branch = repo.active_branch.name
except Exception:
pass

if save and mlrun.mlconf.dbpath:
project.save()
project.register_artifacts()
if sync_functions:
project.sync_functions(names=project.get_function_names(), save=True)

elif sync_functions:
project.sync_functions(names=project.get_function_names(), save=False)

_set_as_current_default_project(project)

Expand Down Expand Up @@ -1699,6 +1707,10 @@ def get_function_objects(self) -> typing.Dict[str, mlrun.runtimes.BaseRuntime]:
self.sync_functions()
return FunctionsDict(self)

def get_function_names(self) -> typing.List[str]:
"""get a list of all the project function names"""
return [func["name"] for func in self.spec.functions]

def pull(self, branch=None, remote=None):
"""pull/update sources from git or tar into the context dir
Expand Down Expand Up @@ -2839,7 +2851,7 @@ def _init_function_from_dict(f, project, name=None):
raise ValueError(
"function with db:// or hub:// url or .yaml file, does not support tag value "
)
func = import_function(url)
func = import_function(url, new_name=name)
if image:
func.spec.image = image
elif url.endswith(".ipynb"):
Expand Down
5 changes: 5 additions & 0 deletions tests/common_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,14 @@ def __init__(self):
self.kind = "http"
self._pipeline = None
self._function = None
self._artifact = None

def reset(self):
self._function = None
self._pipeline = None
self._project_name = None
self._project = None
self._artifact = None

# Expected to return a hash-key
def store_function(self, function, name, project="", tag=None, versioned=False):
Expand All @@ -208,6 +210,9 @@ def store_run(self, struct, uid, project="", iter=0):
}
}

def store_artifact(self, key, artifact, uid, iter=None, tag="", project=""):
self._artifact = artifact

def get_function(self, function, project, tag):
return {
"name": function,
Expand Down
40 changes: 40 additions & 0 deletions tests/projects/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,46 @@ def test_load_project(
assert os.path.exists(os.path.join(context, project_file))


@pytest.mark.parametrize(
"sync,expected_num_of_funcs, save",
[
(
False,
0,
False,
),
(
True,
4,
False,
),
(
True,
4,
True,
),
],
)
def test_load_project_and_sync_functions(
context, rundb_mock, sync, expected_num_of_funcs, save
):
url = "git://github.com/mlrun/project-demo.git"
project = mlrun.load_project(
context=str(context), url=url, sync_functions=sync, save=save
)
assert len(project.spec._function_objects) == expected_num_of_funcs

if sync:
function_names = project.get_function_names()
assert len(function_names) == expected_num_of_funcs
for func in function_names:
fn = project.get_function(func)
assert fn.metadata.name == func, "func did not return"

if save:
assert rundb_mock._function is not None


def _assert_project_function_objects(project, expected_function_objects):
project_function_objects = project.spec._function_objects
assert len(project_function_objects) == len(expected_function_objects)
Expand Down

0 comments on commit 2df68aa

Please sign in to comment.