diff --git a/automation/package_test/test.py b/automation/package_test/test.py index 8f5766e8f1..c517ee2270 100644 --- a/automation/package_test/test.py +++ b/automation/package_test/test.py @@ -42,7 +42,9 @@ def __init__(self): } def run(self): - self._logger.info("Running package tests",) + self._logger.info( + "Running package tests", + ) results = {} for extra, extra_tests_data in self._extras_tests_data.items(): @@ -84,13 +86,15 @@ def _run_test(self, test_function, extra, results, test_key): def _test_extra_imports(self, extra): self._logger.debug( - "Testing extra imports", extra=extra, + "Testing extra imports", + extra=extra, ) test_command = ( f"python -c '{self._extras_tests_data[extra]['import_test_command']}'" ) self._run_command( - test_command, run_in_venv=True, + test_command, + run_in_venv=True, ) if "api" not in extra: # When api is not in the extra it's an extra purposed for the client usage @@ -104,33 +108,47 @@ def _test_extra_imports(self, extra): def _test_requirements_conflicts(self, extra): self._logger.debug( - "Testing requirements conflicts", extra=extra, + "Testing requirements conflicts", + extra=extra, ) self._run_command( - "pip install pipdeptree", run_in_venv=True, + "pip install pipdeptree", + run_in_venv=True, ) self._run_command( - "pipdeptree --warn fail", run_in_venv=True, + "pipdeptree --warn fail", + run_in_venv=True, ) def _create_venv(self): - self._logger.debug("Creating venv",) - self._run_command("python -m venv test-venv",) + self._logger.debug( + "Creating venv", + ) + self._run_command( + "python -m venv test-venv", + ) def _clean_venv(self): - self._logger.debug("Cleaning venv",) - self._run_command("rm -rf test-venv",) + self._logger.debug( + "Cleaning venv", + ) + self._run_command( + "rm -rf test-venv", + ) def _install_extra(self, extra): self._logger.debug( - "Installing extra", extra=extra, + "Installing extra", + extra=extra, ) self._run_command( - "python -m pip install --upgrade pip~=22.0.0", run_in_venv=True, + "python -m pip install --upgrade pip~=22.0.0", + run_in_venv=True, ) self._run_command( - f"pip install '.{extra}'", run_in_venv=True, + f"pip install '.{extra}'", + run_in_venv=True, ) def _run_command(self, command, run_in_venv=False, env=None): diff --git a/automation/release_notes/generate.py b/automation/release_notes/generate.py index b29b4d9913..fc6b61cb38 100644 --- a/automation/release_notes/generate.py +++ b/automation/release_notes/generate.py @@ -28,7 +28,10 @@ class ReleaseNotesGenerator: ) def __init__( - self, release: str, previous_release: str, release_branch: str, + self, + release: str, + previous_release: str, + release_branch: str, ): self._logger = logger self._release = release @@ -217,7 +220,9 @@ def main(): @click.argument("previous-release", type=str, required=True) @click.argument("release-branch", type=str, required=False, default="master") def run( - release: str, previous_release: str, release_branch: str, + release: str, + previous_release: str, + release_branch: str, ): release_notes_generator = ReleaseNotesGenerator( release, previous_release, release_branch diff --git a/automation/requirements.txt b/automation/requirements.txt index fe10aa3af2..1bec788d28 100644 --- a/automation/requirements.txt +++ b/automation/requirements.txt @@ -1,4 +1,4 @@ -click~=7.0 +click~=8.0 paramiko~=2.7 semver~=2.13 requests~=2.22 diff --git a/automation/system_test/prepare.py b/automation/system_test/prepare.py index 22d04b0a47..617e5870be 100644 --- a/automation/system_test/prepare.py +++ 
b/automation/system_test/prepare.py @@ -287,13 +287,17 @@ def extract_version_from_release(release): def _prepare_test_env(self): self._run_command( - "mkdir", args=["-p", str(self.Constants.workdir)], + "mkdir", + args=["-p", str(self.Constants.workdir)], ) contents = yaml.safe_dump(self._env_config) filepath = str(self.Constants.system_tests_env_yaml) self._logger.debug("Populating system tests env.yml", filepath=filepath) self._run_command( - "cat > ", args=[filepath], stdin=contents, local=True, + "cat > ", + args=[filepath], + stdin=contents, + local=True, ) def _override_mlrun_api_env(self): @@ -326,7 +330,8 @@ def _override_mlrun_api_env(self): ) self._run_command( - "kubectl", args=["apply", "-f", manifest_file_name], + "kubectl", + args=["apply", "-f", manifest_file_name], ) def _download_provctl(self): diff --git a/automation/version/version_file.py b/automation/version/version_file.py index 2a1e91a3ad..dee767ffed 100644 --- a/automation/version/version_file.py +++ b/automation/version/version_file.py @@ -57,11 +57,18 @@ def _run_command(command, args=None): if sys.version_info[0] >= 3: process = subprocess.run( - command, shell=True, check=True, capture_output=True, encoding="utf-8", + command, + shell=True, + check=True, + capture_output=True, + encoding="utf-8", ) output = process.stdout else: - output = subprocess.check_output(command, shell=True,) + output = subprocess.check_output( + command, + shell=True, + ) return output diff --git a/dev-requirements.txt b/dev-requirements.txt index fcd762fc5b..7e637b5b61 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,7 +1,7 @@ pytest~=5.4 twine~=3.1 # TODO: bump black when it become stable -black<=19.10b0 +black~=22.0 flake8~=3.8 pytest-asyncio~=0.15.0 pytest-alembic~=0.4.0 diff --git a/dockerfiles/mlrun-api/requirements.txt b/dockerfiles/mlrun-api/requirements.txt index f15f33fd0e..cb44c97a5c 100644 --- a/dockerfiles/mlrun-api/requirements.txt +++ b/dockerfiles/mlrun-api/requirements.txt @@ -1,4 +1,4 @@ -uvicorn~=0.12.0 +uvicorn~=0.17.0 dask-kubernetes~=0.11.0 apscheduler~=3.6 sqlite3-to-mysql~=1.4 diff --git a/mlrun/__main__.py b/mlrun/__main__.py index c032d1488f..7d880a06b9 100644 --- a/mlrun/__main__.py +++ b/mlrun/__main__.py @@ -72,7 +72,7 @@ def main(): @click.option( "--param", "-p", - default="", + default=[], multiple=True, help="parameter name and value tuples, e.g. -p x=37 -p y='text'", ) @@ -81,7 +81,10 @@ def main(): @click.option("--in-path", help="default input path/url (prefix) for artifact") @click.option("--out-path", help="default output path/url (prefix) for artifact") @click.option( - "--secrets", "-s", multiple=True, help="secrets file= or env=ENV_KEY1,.." + "--secrets", + "-s", + multiple=True, + help="secrets file= or env=ENV_KEY1,..", ) @click.option("--uid", help="unique run ID") @click.option("--name", help="run name") @@ -97,7 +100,7 @@ def main(): @click.option( "--hyperparam", "-x", - default="", + default=[], multiple=True, help="hyper parameters (will expand to multiple tasks) e.g. --hyperparam p2=[1,2,3]", ) @@ -115,7 +118,9 @@ def main(): help="hyperparam tuning strategy list | grid | random", ) @click.option( - "--hyper-param-options", default="", help="hyperparam options json string", + "--hyper-param-options", + default="", + help="hyperparam options json string", ) @click.option( "--func-url", @@ -379,7 +384,7 @@ def run( @click.option( "--command", "-c", - default="", + default=[], multiple=True, help="build commands, e.g. 
'-c pip install pandas'", ) @@ -774,7 +779,7 @@ def logs(uid, project, offset, db, watch): @click.option( "--arguments", "-a", - default="", + default=[], multiple=True, help="Kubeflow pipeline arguments name and value tuples (with -r flag), e.g. -a x=6", ) @@ -782,12 +787,15 @@ def logs(uid, project, offset, db, watch): @click.option( "--param", "-x", - default="", + default=[], multiple=True, help="mlrun project parameter name and value tuples, e.g. -p x=37 -p y='text'", ) @click.option( - "--secrets", "-s", multiple=True, help="secrets file= or env=ENV_KEY1,.." + "--secrets", + "-s", + multiple=True, + help="secrets file= or env=ENV_KEY1,..", ) @click.option("--namespace", help="k8s namespace") @click.option("--db", help="api and db service path/url") diff --git a/mlrun/api/api/endpoints/artifacts.py b/mlrun/api/api/endpoints/artifacts.py index 3575326935..ffb534c081 100644 --- a/mlrun/api/api/endpoints/artifacts.py +++ b/mlrun/api/api/endpoints/artifacts.py @@ -71,14 +71,19 @@ def list_artifact_tags( db_session: Session = Depends(deps.get_db_session), ): mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) tag_tuples = get_db().list_artifact_tags(db_session, project) artifact_key_to_tag = {tag_tuple[1]: tag_tuple[2] for tag_tuple in tag_tuples} allowed_artifact_keys = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( mlrun.api.schemas.AuthorizationResourceTypes.artifact, list(artifact_key_to_tag.keys()), - lambda artifact_key: (project, artifact_key,), + lambda artifact_key: ( + project, + artifact_key, + ), auth_info, ) tags = [ @@ -150,7 +155,9 @@ def list_artifacts( if project is None: project = config.default_project mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) artifacts = mlrun.api.crud.Artifacts().list_artifacts( diff --git a/mlrun/api/api/endpoints/background_tasks.py b/mlrun/api/api/endpoints/background_tasks.py index 527004b0c2..f1ff12fa4e 100644 --- a/mlrun/api/api/endpoints/background_tasks.py +++ b/mlrun/api/api/endpoints/background_tasks.py @@ -35,7 +35,8 @@ def get_project_background_task( @router.get( - "/background-tasks/{name}", response_model=mlrun.api.schemas.BackgroundTask, + "/background-tasks/{name}", + response_model=mlrun.api.schemas.BackgroundTask, ) def get_background_task( name: str, diff --git a/mlrun/api/api/endpoints/client_spec.py b/mlrun/api/api/endpoints/client_spec.py index 57f7a9dc42..d4fcc58152 100644 --- a/mlrun/api/api/endpoints/client_spec.py +++ b/mlrun/api/api/endpoints/client_spec.py @@ -7,7 +7,8 @@ @router.get( - "/client-spec", response_model=mlrun.api.schemas.ClientSpec, + "/client-spec", + response_model=mlrun.api.schemas.ClientSpec, ) def get_client_spec(): return mlrun.api.crud.ClientSpec().get_client_spec() diff --git a/mlrun/api/api/endpoints/feature_store.py b/mlrun/api/api/endpoints/feature_store.py index 7d82308024..159a5d3cfe 100644 --- a/mlrun/api/api/endpoints/feature_store.py +++ b/mlrun/api/api/endpoints/feature_store.py @@ -40,7 +40,10 @@ def create_feature_set( auth_info, ) feature_set_uid = mlrun.api.crud.FeatureStore().create_feature_set( - db_session, project, feature_set, versioned, + db_session, + project, + feature_set, + versioned, ) return 
mlrun.api.crud.FeatureStore().get_feature_set( @@ -77,10 +80,20 @@ def store_feature_set( ) tag, uid = parse_reference(reference) uid = mlrun.api.crud.FeatureStore().store_feature_set( - db_session, project, name, feature_set, tag, uid, versioned, + db_session, + project, + name, + feature_set, + tag, + uid, + versioned, ) return mlrun.api.crud.FeatureStore().get_feature_set( - db_session, project, feature_set.metadata.name, tag, uid, + db_session, + project, + feature_set.metadata.name, + tag, + uid, ) @@ -105,7 +118,13 @@ def patch_feature_set( ) tag, uid = parse_reference(reference) mlrun.api.crud.FeatureStore().patch_feature_set( - db_session, project, name, feature_set_update, tag, uid, patch_mode, + db_session, + project, + name, + feature_set_update, + tag, + uid, + patch_mode, ) return Response(status_code=HTTPStatus.OK.value) @@ -183,7 +202,9 @@ def list_feature_sets( db_session: Session = Depends(deps.get_db_session), ): mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) feature_sets = mlrun.api.crud.FeatureStore().list_feature_sets( db_session, @@ -202,7 +223,10 @@ def list_feature_sets( feature_sets = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( mlrun.api.schemas.AuthorizationResourceTypes.feature_set, feature_sets.feature_sets, - lambda feature_set: (feature_set.metadata.project, feature_set.metadata.name,), + lambda feature_set: ( + feature_set.metadata.project, + feature_set.metadata.name, + ), auth_info, ) return mlrun.api.schemas.FeatureSetsOutput(feature_sets=feature_sets) @@ -223,16 +247,22 @@ def list_feature_sets_tags( "Listing a specific feature set tags is not supported, set name to *" ) mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) tag_tuples = mlrun.api.crud.FeatureStore().list_feature_sets_tags( - db_session, project, + db_session, + project, ) feature_set_name_to_tag = {tag_tuple[1]: tag_tuple[2] for tag_tuple in tag_tuples} allowed_feature_set_names = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( mlrun.api.schemas.AuthorizationResourceTypes.feature_set, list(feature_set_name_to_tag.keys()), - lambda feature_set_name: (project, feature_set_name,), + lambda feature_set_name: ( + project, + feature_set_name, + ), auth_info, ) tags = { @@ -378,7 +408,9 @@ def list_features( db_session: Session = Depends(deps.get_db_session), ): mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) features = mlrun.api.crud.FeatureStore().list_features( db_session, project, name, tag, entities, labels @@ -405,7 +437,9 @@ def list_entities( db_session: Session = Depends(deps.get_db_session), ): mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) entities = mlrun.api.crud.FeatureStore().list_entities( db_session, project, name, tag, labels @@ -446,7 +480,10 @@ def create_feature_vector( auth_info, project, feature_vector.dict() ) feature_vector_uid = 
mlrun.api.crud.FeatureStore().create_feature_vector( - db_session, project, feature_vector, versioned, + db_session, + project, + feature_vector, + versioned, ) return mlrun.api.crud.FeatureStore().get_feature_vector( @@ -507,7 +544,9 @@ def list_feature_vectors( db_session: Session = Depends(deps.get_db_session), ): mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) feature_vectors = mlrun.api.crud.FeatureStore().list_feature_vectors( db_session, @@ -552,10 +591,13 @@ def list_feature_vectors_tags( "Listing a specific feature vector tags is not supported, set name to *" ) mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) tag_tuples = mlrun.api.crud.FeatureStore().list_feature_vectors_tags( - db_session, project, + db_session, + project, ) feature_vector_name_to_tag = { tag_tuple[1]: tag_tuple[2] for tag_tuple in tag_tuples @@ -563,7 +605,10 @@ def list_feature_vectors_tags( allowed_feature_vector_names = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( mlrun.api.schemas.AuthorizationResourceTypes.feature_vector, list(feature_vector_name_to_tag.keys()), - lambda feature_vector_name: (project, feature_vector_name,), + lambda feature_vector_name: ( + project, + feature_vector_name, + ), auth_info, ) tags = { @@ -602,7 +647,13 @@ def store_feature_vector( ) tag, uid = parse_reference(reference) uid = mlrun.api.crud.FeatureStore().store_feature_vector( - db_session, project, name, feature_vector, tag, uid, versioned, + db_session, + project, + name, + feature_vector, + tag, + uid, + versioned, ) return mlrun.api.crud.FeatureStore().get_feature_vector( @@ -634,7 +685,13 @@ def patch_feature_vector( ) tag, uid = parse_reference(reference) mlrun.api.crud.FeatureStore().patch_feature_vector( - db_session, project, name, feature_vector_patch, tag, uid, patch_mode, + db_session, + project, + name, + feature_vector_patch, + tag, + uid, + patch_mode, ) return Response(status_code=HTTPStatus.OK.value) diff --git a/mlrun/api/api/endpoints/frontend_spec.py b/mlrun/api/api/endpoints/frontend_spec.py index 50834a8a56..7975847be7 100644 --- a/mlrun/api/api/endpoints/frontend_spec.py +++ b/mlrun/api/api/endpoints/frontend_spec.py @@ -15,7 +15,8 @@ @router.get( - "/frontend-spec", response_model=mlrun.api.schemas.FrontendSpec, + "/frontend-spec", + response_model=mlrun.api.schemas.FrontendSpec, ) def get_frontend_spec( auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( @@ -33,8 +34,14 @@ def get_frontend_spec( feature_flags = _resolve_feature_flags() registry, repository = mlrun.utils.helpers.get_parsed_docker_registry() repository = mlrun.utils.helpers.get_docker_repository_or_default(repository) - function_deployment_target_image_template = mlrun.runtimes.utils.fill_function_image_name_template( - f"{registry}/", repository, "{project}", "{name}", "{tag}", + function_deployment_target_image_template = ( + mlrun.runtimes.utils.fill_function_image_name_template( + f"{registry}/", + repository, + "{project}", + "{name}", + "{tag}", + ) ) registries_to_enforce_prefix = ( mlrun.runtimes.utils.resolve_function_target_image_registries_to_enforce_prefix() diff --git a/mlrun/api/api/endpoints/functions.py b/mlrun/api/api/endpoints/functions.py index 
f1dd73580a..b664820ab3 100644 --- a/mlrun/api/api/endpoints/functions.py +++ b/mlrun/api/api/endpoints/functions.py @@ -143,7 +143,9 @@ def list_functions( if project is None: project = config.default_project mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) functions = mlrun.api.crud.Functions().list_functions( db_session, project, name, tag, labels @@ -348,7 +350,12 @@ def build_status( # the DB with intermediate or unusable versions, only successfully deployed versions versioned = True mlrun.api.crud.Functions().store_function( - db_session, fn, name, project, tag, versioned=versioned, + db_session, + fn, + name, + project, + tag, + versioned=versioned, ) return Response( content=text, @@ -428,7 +435,12 @@ def build_status( if state == mlrun.api.schemas.FunctionState.ready: versioned = True mlrun.api.crud.Functions().store_function( - db_session, fn, name, project, tag, versioned=versioned, + db_session, + fn, + name, + project, + tag, + versioned=versioned, ) return Response( @@ -681,7 +693,10 @@ def _process_model_monitoring_secret(db_session, project_name: str, secret_key: provider = SecretProviderName.kubernetes secret_value = Secrets().get_secret( - project_name, provider, secret_key, allow_secrets_from_k8s=True, + project_name, + provider, + secret_key, + allow_secrets_from_k8s=True, ) user_provided_key = secret_value is not None internal_key_name = Secrets().generate_model_monitoring_secret_key(secret_key) diff --git a/mlrun/api/api/endpoints/grafana_proxy.py b/mlrun/api/api/endpoints/grafana_proxy.py index e3aeff6504..69f504b05f 100644 --- a/mlrun/api/api/endpoints/grafana_proxy.py +++ b/mlrun/api/api/endpoints/grafana_proxy.py @@ -128,7 +128,9 @@ def grafana_list_endpoints( if project: mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) endpoint_list = mlrun.api.crud.ModelEndpoints().list_endpoints( auth_info=auth_info, @@ -143,7 +145,10 @@ def grafana_list_endpoints( allowed_endpoints = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, endpoint_list.endpoints, - lambda _endpoint: (_endpoint.metadata.project, _endpoint.metadata.uid,), + lambda _endpoint: ( + _endpoint.metadata.project, + _endpoint.metadata.uid, + ), auth_info, ) endpoint_list.endpoints = allowed_endpoints @@ -341,7 +346,9 @@ def grafana_incoming_features( _, container, path = parse_model_endpoint_store_prefix(path) client = get_frames_client( - token=auth_info.data_session, address=config.v3io_framesd, container=container, + token=auth_info.data_session, + address=config.v3io_framesd, + container=container, ) data: pd.DataFrame = client.read( @@ -354,7 +361,7 @@ def grafana_incoming_features( ) data.drop(["endpoint_id"], axis=1, inplace=True, errors="ignore") - data.index = data.index.astype(np.int64) // 10 ** 6 + data.index = data.index.astype(np.int64) // 10**6 for feature, indexed_values in data.to_dict().items(): target = GrafanaTimeSeriesTarget(target=feature) diff --git a/mlrun/api/api/endpoints/healthz.py b/mlrun/api/api/endpoints/healthz.py index 041ad052a3..820a719e14 100644 --- a/mlrun/api/api/endpoints/healthz.py +++ b/mlrun/api/api/endpoints/healthz.py @@ -7,7 +7,8 @@ @router.get( - 
"/healthz", response_model=mlrun.api.schemas.ClientSpec, + "/healthz", + response_model=mlrun.api.schemas.ClientSpec, ) def health(): diff --git a/mlrun/api/api/endpoints/logs.py b/mlrun/api/api/endpoints/logs.py index 7c05299de8..432e0a60db 100644 --- a/mlrun/api/api/endpoints/logs.py +++ b/mlrun/api/api/endpoints/logs.py @@ -30,7 +30,11 @@ async def store_log( ) body = await request.body() await fastapi.concurrency.run_in_threadpool( - mlrun.api.crud.Logs().store_log, body, project, uid, append, + mlrun.api.crud.Logs().store_log, + body, + project, + uid, + append, ) return {} diff --git a/mlrun/api/api/endpoints/marketplace.py b/mlrun/api/api/endpoints/marketplace.py index c90a52714d..5fa158aec7 100644 --- a/mlrun/api/api/endpoints/marketplace.py +++ b/mlrun/api/api/endpoints/marketplace.py @@ -45,7 +45,8 @@ def create_source( @router.get( - path="/marketplace/sources", response_model=List[IndexedMarketplaceSource], + path="/marketplace/sources", + response_model=List[IndexedMarketplaceSource], ) def list_sources( db_session: Session = Depends(mlrun.api.api.deps.get_db_session), @@ -63,7 +64,8 @@ def list_sources( @router.delete( - path="/marketplace/sources/{source_name}", status_code=HTTPStatus.NO_CONTENT.value, + path="/marketplace/sources/{source_name}", + status_code=HTTPStatus.NO_CONTENT.value, ) def delete_source( source_name: str, @@ -83,7 +85,8 @@ def delete_source( @router.get( - path="/marketplace/sources/{source_name}", response_model=IndexedMarketplaceSource, + path="/marketplace/sources/{source_name}", + response_model=IndexedMarketplaceSource, ) def get_source( source_name: str, @@ -127,7 +130,8 @@ def store_source( @router.get( - path="/marketplace/sources/{source_name}/items", response_model=MarketplaceCatalog, + path="/marketplace/sources/{source_name}/items", + response_model=MarketplaceCatalog, ) def get_catalog( source_name: str, @@ -180,7 +184,9 @@ def get_item( ) -@router.get("/marketplace/sources/{source_name}/item-object",) +@router.get( + "/marketplace/sources/{source_name}/item-object", +) def get_object( source_name: str, url: str, diff --git a/mlrun/api/api/endpoints/model_endpoints.py b/mlrun/api/api/endpoints/model_endpoints.py index f8281b1d4c..10bb28dfa3 100644 --- a/mlrun/api/api/endpoints/model_endpoints.py +++ b/mlrun/api/api/endpoints/model_endpoints.py @@ -108,25 +108,27 @@ def list_endpoints( ), ) -> ModelEndpointList: """ - Returns a list of endpoints of type 'ModelEndpoint', supports filtering by model, function, tag, - labels or top level. - If uids are passed: will return ModelEndpointList of endpoints with uid in uids - Labels can be used to filter on the existence of a label: - api/projects/{project}/model-endpoints/?label=mylabel + Returns a list of endpoints of type 'ModelEndpoint', supports filtering by model, function, tag, + labels or top level. 
+ If uids are passed: will return ModelEndpointList of endpoints with uid in uids + Labels can be used to filter on the existence of a label: + api/projects/{project}/model-endpoints/?label=mylabel - Or on the value of a given label: - api/projects/{project}/model-endpoints/?label=mylabel=1 + Or on the value of a given label: + api/projects/{project}/model-endpoints/?label=mylabel=1 - Multiple labels can be queried in a single request by either using "&" separator: - api/projects/{project}/model-endpoints/?label=mylabel=1&label=myotherlabel=2 + Multiple labels can be queried in a single request by either using "&" separator: + api/projects/{project}/model-endpoints/?label=mylabel=1&label=myotherlabel=2 - Or by using a "," (comma) separator: - api/projects/{project}/model-endpoints/?label=mylabel=1,myotherlabel=2 - Top level: if true will return only routers and endpoint that are NOT children of any router - """ + Or by using a "," (comma) separator: + api/projects/{project}/model-endpoints/?label=mylabel=1,myotherlabel=2 + Top level: if true will return only routers and endpoint that are NOT children of any router + """ mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) endpoints = mlrun.api.crud.ModelEndpoints().list_endpoints( @@ -144,7 +146,10 @@ def list_endpoints( allowed_endpoints = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, endpoints.endpoints, - lambda _endpoint: (_endpoint.metadata.project, _endpoint.metadata.uid,), + lambda _endpoint: ( + _endpoint.metadata.project, + _endpoint.metadata.uid, + ), auth_info, ) diff --git a/mlrun/api/api/endpoints/operations.py b/mlrun/api/api/endpoints/operations.py index c5cf647ac8..c272a7bedd 100644 --- a/mlrun/api/api/endpoints/operations.py +++ b/mlrun/api/api/endpoints/operations.py @@ -23,14 +23,17 @@ }, ) def start_migration( - background_tasks: fastapi.BackgroundTasks, response: fastapi.Response, + background_tasks: fastapi.BackgroundTasks, + response: fastapi.Response, ): # we didn't yet decide who should have permissions to such actions, therefore no authorization at the moment # note in api.py we do declare to use the authenticate_request dependency - meaning we do have authentication global current_migration_background_task_name if mlrun.mlconf.httpdb.state == mlrun.api.schemas.APIStates.migrations_in_progress: - background_task = mlrun.api.utils.background_tasks.Handler().get_background_task( - current_migration_background_task_name + background_task = ( + mlrun.api.utils.background_tasks.Handler().get_background_task( + current_migration_background_task_name + ) ) response.status_code = http.HTTPStatus.ACCEPTED.value return background_task @@ -44,7 +47,8 @@ def start_migration( return fastapi.Response(status_code=http.HTTPStatus.OK.value) logger.info("Starting the migration process") background_task = mlrun.api.utils.background_tasks.Handler().create_background_task( - background_tasks, _perform_migration, + background_tasks, + _perform_migration, ) current_migration_background_task_name = background_task.metadata.name response.status_code = http.HTTPStatus.ACCEPTED.value diff --git a/mlrun/api/api/endpoints/pipelines.py b/mlrun/api/api/endpoints/pipelines.py index 550868a4f4..5707bc5a7f 100644 --- a/mlrun/api/api/endpoints/pipelines.py +++ 
b/mlrun/api/api/endpoints/pipelines.py @@ -43,7 +43,9 @@ def list_pipelines( namespace = config.namespace if project != "*": mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) total_size, next_page_token, runs = None, None, [] if get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster(): @@ -67,7 +69,10 @@ def list_pipelines( allowed_runs = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( mlrun.api.schemas.AuthorizationResourceTypes.pipeline, runs, - lambda run: (run["project"], run["id"],), + lambda run: ( + run["project"], + run["id"], + ), auth_info, ) if format_ == mlrun.api.schemas.PipelinesFormat.name_only: @@ -94,7 +99,11 @@ async def submit_pipeline_legacy( if namespace is None: namespace = config.namespace response = await _create_pipeline( - auth_info, request, namespace, experiment_name, run_name, + auth_info, + request, + namespace, + experiment_name, + run_name, ) return response diff --git a/mlrun/api/api/endpoints/projects.py b/mlrun/api/api/endpoints/projects.py index 89f0a903e0..4f362fc386 100644 --- a/mlrun/api/api/endpoints/projects.py +++ b/mlrun/api/api/endpoints/projects.py @@ -131,7 +131,9 @@ def get_project( # skip permission check if it's the leader if not _is_request_from_leader(auth_info.projects_role): mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - name, mlrun.api.schemas.AuthorizationAction.read, auth_info, + name, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) return project @@ -199,8 +201,11 @@ def list_projects( auth_info.projects_role, auth_info.session, ) - allowed_project_names = mlrun.api.utils.auth.verifier.AuthVerifier().filter_projects_by_permissions( - projects_output.projects, auth_info, + allowed_project_names = ( + mlrun.api.utils.auth.verifier.AuthVerifier().filter_projects_by_permissions( + projects_output.projects, + auth_info, + ) ) return get_project_member().list_projects( db_session, diff --git a/mlrun/api/api/endpoints/runs.py b/mlrun/api/api/endpoints/runs.py index dce382ff3e..e012df2a50 100644 --- a/mlrun/api/api/endpoints/runs.py +++ b/mlrun/api/api/endpoints/runs.py @@ -49,7 +49,12 @@ async def store_run( logger.info("Storing run", data=data) await run_in_threadpool( - mlrun.api.crud.Runs().store_run, db_session, data, uid, iter, project, + mlrun.api.crud.Runs().store_run, + db_session, + data, + uid, + iter, + project, ) return {} @@ -78,7 +83,12 @@ async def update_run( log_and_raise(HTTPStatus.BAD_REQUEST.value, reason="bad JSON body") await run_in_threadpool( - mlrun.api.crud.Runs().update_run, db_session, project, uid, iter, data, + mlrun.api.crud.Runs().update_run, + db_session, + project, + uid, + iter, + data, ) return {} @@ -120,7 +130,10 @@ def delete_run( auth_info, ) mlrun.api.crud.Runs().delete_run( - db_session, uid, iter, project, + db_session, + uid, + iter, + project, ) return {} @@ -154,7 +167,9 @@ def list_runs( ): if project != "*": mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) runs = mlrun.api.crud.Runs().list_runs( db_session, @@ -221,6 +236,11 @@ def delete_runs( auth_info, ) mlrun.api.crud.Runs().delete_runs( - db_session, name, project, labels, state, days_ago, + db_session, + name, + project, + labels, + 
state, + days_ago, ) return {} diff --git a/mlrun/api/api/endpoints/runtime_resources.py b/mlrun/api/api/endpoints/runtime_resources.py index 2bfbf61fd7..01ac1bcd96 100644 --- a/mlrun/api/api/endpoints/runtime_resources.py +++ b/mlrun/api/api/endpoints/runtime_resources.py @@ -219,12 +219,22 @@ def _delete_runtime_resources( else: computed_label_selector = permissions_label_selector mlrun.api.crud.RuntimeResources().delete_runtime_resources( - db_session, kind, object_id, computed_label_selector, force, grace_period, + db_session, + kind, + object_id, + computed_label_selector, + force, + grace_period, ) if is_non_project_runtime_resource_exists: # delete one more time, without adding the allowed projects selector mlrun.api.crud.RuntimeResources().delete_runtime_resources( - db_session, kind, object_id, label_selector, force, grace_period, + db_session, + kind, + object_id, + label_selector, + force, + grace_period, ) if return_body: filtered_projects = copy.deepcopy(allowed_projects) @@ -261,7 +271,9 @@ def _list_runtime_resources( project, auth_info, label_selector, kind_filter, object_id ) return mlrun.api.crud.RuntimeResources().filter_and_format_grouped_by_project_runtime_resources_output( - grouped_by_project_runtime_resources_output, allowed_projects, group_by, + grouped_by_project_runtime_resources_output, + allowed_projects, + group_by, ) @@ -277,15 +289,19 @@ def _get_runtime_resources_allowed_projects( ]: if project != "*": mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) grouped_by_project_runtime_resources_output: mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput - grouped_by_project_runtime_resources_output = mlrun.api.crud.RuntimeResources().list_runtime_resources( - project, - kind, - object_id, - label_selector, - mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + grouped_by_project_runtime_resources_output = ( + mlrun.api.crud.RuntimeResources().list_runtime_resources( + project, + kind, + object_id, + label_selector, + mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + ) ) projects = [] is_non_project_runtime_resource_exists = False @@ -300,7 +316,10 @@ def _get_runtime_resources_allowed_projects( allowed_projects = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( mlrun.api.schemas.AuthorizationResourceTypes.runtime_resource, projects, - lambda project: (project, "",), + lambda project: ( + project, + "", + ), auth_info, action=action, ) @@ -311,5 +330,7 @@ def _get_runtime_resources_allowed_projects( ) -def _generate_label_selector_for_allowed_projects(allowed_projects: typing.List[str],): +def _generate_label_selector_for_allowed_projects( + allowed_projects: typing.List[str], +): return f"mlrun/project in ({', '.join(allowed_projects)})" diff --git a/mlrun/api/api/endpoints/schedules.py b/mlrun/api/api/endpoints/schedules.py index d672143c40..039393fca7 100644 --- a/mlrun/api/api/endpoints/schedules.py +++ b/mlrun/api/api/endpoints/schedules.py @@ -87,7 +87,9 @@ def list_schedules( db_session: Session = Depends(deps.get_db_session), ): mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, mlrun.api.schemas.AuthorizationAction.read, auth_info, + project, + mlrun.api.schemas.AuthorizationAction.read, + auth_info, ) schedules = get_scheduler().list_schedules( db_session, project, name, kind, labels, include_last_run, 
include_credentials @@ -95,7 +97,10 @@ def list_schedules( filtered_schedules = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( mlrun.api.schemas.AuthorizationResourceTypes.schedule, schedules.schedules, - lambda schedule: (schedule.project, schedule.name,), + lambda schedule: ( + schedule.project, + schedule.name, + ), auth_info, ) schedules.schedules = filtered_schedules @@ -170,7 +175,10 @@ def delete_schedules( auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): - schedules = get_scheduler().list_schedules(db_session, project,) + schedules = get_scheduler().list_schedules( + db_session, + project, + ) mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resources_permissions( mlrun.api.schemas.AuthorizationResourceTypes.schedule, schedules.schedules, diff --git a/mlrun/api/api/endpoints/secrets.py b/mlrun/api/api/endpoints/secrets.py index 14b7ca1ae2..bc0c148f47 100644 --- a/mlrun/api/api/endpoints/secrets.py +++ b/mlrun/api/api/endpoints/secrets.py @@ -117,7 +117,9 @@ def list_secrets( @router.post("/user-secrets", status_code=HTTPStatus.CREATED.value) -def add_user_secrets(secrets: schemas.UserSecretCreationRequest,): +def add_user_secrets( + secrets: schemas.UserSecretCreationRequest, +): if secrets.provider != schemas.SecretProviderName.vault: return fastapi.Response( status_code=HTTPStatus.BAD_REQUEST.vault, diff --git a/mlrun/api/api/utils.py b/mlrun/api/api/utils.py index afe198e18a..96069fac70 100644 --- a/mlrun/api/api/utils.py +++ b/mlrun/api/api/utils.py @@ -71,7 +71,9 @@ def get_secrets(auth_info: mlrun.api.schemas.AuthInfo): } -def get_run_db_instance(db_session: Session,): +def get_run_db_instance( + db_session: Session, +): db = get_db() if isinstance(db, SQLDB): run_db = SQLRunDB(db.dsn, db_session) diff --git a/mlrun/api/crud/artifacts.py b/mlrun/api/crud/artifacts.py index f263ff7cef..8c83917bb3 100644 --- a/mlrun/api/crud/artifacts.py +++ b/mlrun/api/crud/artifacts.py @@ -11,7 +11,9 @@ import mlrun.utils.singleton -class Artifacts(metaclass=mlrun.utils.singleton.Singleton,): +class Artifacts( + metaclass=mlrun.utils.singleton.Singleton, +): def store_artifact( self, db_session: sqlalchemy.orm.Session, @@ -33,7 +35,13 @@ def store_artifact( f"key={key}, uid={uid}, data={data}" ) mlrun.api.utils.singletons.db.get_db().store_artifact( - db_session, key, data, uid, iter, tag, project, + db_session, + key, + data, + uid, + iter, + tag, + project, ) def get_artifact( @@ -46,7 +54,11 @@ def get_artifact( ) -> dict: project = project or mlrun.mlconf.default_project return mlrun.api.utils.singletons.db.get_db().read_artifact( - db_session, key, tag, iter, project, + db_session, + key, + tag, + iter, + project, ) def list_artifacts( diff --git a/mlrun/api/crud/client_spec.py b/mlrun/api/crud/client_spec.py index 739bfb7ac7..5163b1ff81 100644 --- a/mlrun/api/crud/client_spec.py +++ b/mlrun/api/crud/client_spec.py @@ -6,7 +6,9 @@ from mlrun.utils import logger -class ClientSpec(metaclass=mlrun.utils.singleton.Singleton,): +class ClientSpec( + metaclass=mlrun.utils.singleton.Singleton, +): def __init__(self): self._cached_nuclio_version = None diff --git a/mlrun/api/crud/feature_store.py b/mlrun/api/crud/feature_store.py index ac0af93713..688b188bd3 100644 --- a/mlrun/api/crud/feature_store.py +++ b/mlrun/api/crud/feature_store.py @@ -11,7 +11,9 @@ import mlrun.utils.singleton -class FeatureStore(metaclass=mlrun.utils.singleton.Singleton,): +class 
FeatureStore( + metaclass=mlrun.utils.singleton.Singleton, +): def create_feature_set( self, db_session: sqlalchemy.orm.Session, @@ -19,7 +21,12 @@ def create_feature_set( feature_set: mlrun.api.schemas.FeatureSet, versioned: bool = True, ) -> str: - return self._create_object(db_session, project, feature_set, versioned,) + return self._create_object( + db_session, + project, + feature_set, + versioned, + ) def store_feature_set( self, @@ -32,7 +39,13 @@ def store_feature_set( versioned: bool = True, ) -> str: return self._store_object( - db_session, project, name, feature_set, tag, uid, versioned, + db_session, + project, + name, + feature_set, + tag, + uid, + versioned, ) def patch_feature_set( @@ -69,7 +82,9 @@ def get_feature_set( ) def list_feature_sets_tags( - self, db_session: sqlalchemy.orm.Session, project: str, + self, + db_session: sqlalchemy.orm.Session, + project: str, ) -> typing.List[typing.Tuple[str, str, str]]: """ :return: a list of Tuple of (project, feature_set.name, tag) @@ -118,7 +133,12 @@ def delete_feature_set( uid: typing.Optional[str] = None, ): self._delete_object( - db_session, mlrun.api.schemas.FeatureSet, project, name, tag, uid, + db_session, + mlrun.api.schemas.FeatureSet, + project, + name, + tag, + uid, ) def list_features( @@ -132,7 +152,12 @@ def list_features( ) -> mlrun.api.schemas.FeaturesOutput: project = project or mlrun.mlconf.default_project return mlrun.api.utils.singletons.db.get_db().list_features( - db_session, project, name, tag, entities, labels, + db_session, + project, + name, + tag, + entities, + labels, ) def list_entities( @@ -145,7 +170,11 @@ def list_entities( ) -> mlrun.api.schemas.EntitiesOutput: project = project or mlrun.mlconf.default_project return mlrun.api.utils.singletons.db.get_db().list_entities( - db_session, project, name, tag, labels, + db_session, + project, + name, + tag, + labels, ) def create_feature_vector( @@ -168,7 +197,13 @@ def store_feature_vector( versioned: bool = True, ) -> str: return self._store_object( - db_session, project, name, feature_vector, tag, uid, versioned, + db_session, + project, + name, + feature_vector, + tag, + uid, + versioned, ) def patch_feature_vector( @@ -201,11 +236,18 @@ def get_feature_vector( uid: typing.Optional[str] = None, ) -> mlrun.api.schemas.FeatureVector: return self._get_object( - db_session, mlrun.api.schemas.FeatureVector, project, name, tag, uid, + db_session, + mlrun.api.schemas.FeatureVector, + project, + name, + tag, + uid, ) def list_feature_vectors_tags( - self, db_session: sqlalchemy.orm.Session, project: str, + self, + db_session: sqlalchemy.orm.Session, + project: str, ) -> typing.List[typing.Tuple[str, str, str]]: """ :return: a list of Tuple of (project, feature_vector.name, tag) @@ -250,7 +292,12 @@ def delete_feature_vector( uid: typing.Optional[str] = None, ): self._delete_object( - db_session, mlrun.api.schemas.FeatureVector, project, name, tag, uid, + db_session, + mlrun.api.schemas.FeatureVector, + project, + name, + tag, + uid, ) def _create_object( @@ -295,11 +342,23 @@ def _store_object( ) if isinstance(object_, mlrun.api.schemas.FeatureSet): return mlrun.api.utils.singletons.db.get_db().store_feature_set( - db_session, project, name, object_, tag, uid, versioned, + db_session, + project, + name, + object_, + tag, + uid, + versioned, ) elif isinstance(object_, mlrun.api.schemas.FeatureVector): return mlrun.api.utils.singletons.db.get_db().store_feature_vector( - db_session, project, name, object_, tag, uid, versioned, + db_session, + project, + name, 
+ object_, + tag, + uid, + versioned, ) else: raise NotImplementedError( @@ -319,15 +378,32 @@ def _patch_object( ) -> str: project = project or mlrun.mlconf.default_project self._validate_identity_for_object_patch( - object_schema.__class__.__name__, object_patch, project, name, tag, uid, + object_schema.__class__.__name__, + object_patch, + project, + name, + tag, + uid, ) if object_schema.__name__ == mlrun.api.schemas.FeatureSet.__name__: return mlrun.api.utils.singletons.db.get_db().patch_feature_set( - db_session, project, name, object_patch, tag, uid, patch_mode, + db_session, + project, + name, + object_patch, + tag, + uid, + patch_mode, ) elif object_schema.__name__ == mlrun.api.schemas.FeatureVector.__name__: return mlrun.api.utils.singletons.db.get_db().patch_feature_vector( - db_session, project, name, object_patch, tag, uid, patch_mode, + db_session, + project, + name, + object_patch, + tag, + uid, + patch_mode, ) else: raise NotImplementedError( diff --git a/mlrun/api/crud/functions.py b/mlrun/api/crud/functions.py index 2096ee7c0d..5db3e43263 100644 --- a/mlrun/api/crud/functions.py +++ b/mlrun/api/crud/functions.py @@ -11,7 +11,9 @@ import mlrun.utils.singleton -class Functions(metaclass=mlrun.utils.singleton.Singleton,): +class Functions( + metaclass=mlrun.utils.singleton.Singleton, +): def store_function( self, db_session: sqlalchemy.orm.Session, @@ -23,7 +25,12 @@ def store_function( ) -> str: project = project or mlrun.mlconf.default_project return mlrun.api.utils.singletons.db.get_db().store_function( - db_session, function, name, project, tag, versioned, + db_session, + function, + name, + project, + tag, + versioned, ) def get_function( @@ -40,7 +47,10 @@ def get_function( ) def delete_function( - self, db_session: sqlalchemy.orm.Session, project: str, name: str, + self, + db_session: sqlalchemy.orm.Session, + project: str, + name: str, ): return mlrun.api.utils.singletons.db.get_db().delete_function( db_session, project, name @@ -58,5 +68,9 @@ def list_functions( if labels is None: labels = [] return mlrun.api.utils.singletons.db.get_db().list_functions( - db_session, name, project, tag, labels, + db_session, + name, + project, + tag, + labels, ) diff --git a/mlrun/api/crud/logs.py b/mlrun/api/crud/logs.py index bb3591e245..befc6ffcbb 100644 --- a/mlrun/api/crud/logs.py +++ b/mlrun/api/crud/logs.py @@ -14,9 +14,15 @@ from mlrun.runtimes.constants import PodPhases -class Logs(metaclass=mlrun.utils.singleton.Singleton,): +class Logs( + metaclass=mlrun.utils.singleton.Singleton, +): def store_log( - self, body: bytes, project: str, uid: str, append: bool = True, + self, + body: bytes, + project: str, + uid: str, + append: bool = True, ): project = project or mlrun.mlconf.default_project log_file = log_path(project, uid) @@ -26,7 +32,8 @@ def store_log( fp.write(body) def delete_logs( - self, project: str, + self, + project: str, ): project = project or mlrun.mlconf.default_project logs_path = project_logs_path(project) diff --git a/mlrun/api/crud/model_endpoints.py b/mlrun/api/crud/model_endpoints.py index 52f2601aba..9562968475 100644 --- a/mlrun/api/crud/model_endpoints.py +++ b/mlrun/api/crud/model_endpoints.py @@ -134,7 +134,9 @@ def create_or_patch( logger.info("Updating model endpoint", endpoint_id=model_endpoint.metadata.uid) self.write_endpoint_to_kv( - access_key=access_key, endpoint=model_endpoint, update=True, + access_key=access_key, + endpoint=model_endpoint, + update=True, ) logger.info("Model endpoint updated", endpoint_id=model_endpoint.metadata.uid) @@ 
-241,7 +243,11 @@ def list_endpoints( table_path=path, access_key=auth_info.data_session, filter_expression=self.build_kv_cursor_filter_expression( - project, function, model, labels, top_level, + project, + function, + model, + labels, + top_level, ), attribute_names=["endpoint_id"], raise_for_status=RaiseForStatus.never, @@ -289,7 +295,8 @@ def get_endpoint( """ access_key = self.get_access_key(auth_info) logger.info( - "Getting model endpoint record from kv", endpoint_id=endpoint_id, + "Getting model endpoint record from kv", + endpoint_id=endpoint_id, ) client = get_v3io_client(endpoint=config.v3io_api) @@ -498,7 +505,9 @@ def get_endpoint_metrics( _, container, path = parse_model_endpoint_store_prefix(path) client = get_frames_client( - token=access_key, address=config.v3io_framesd, container=container, + token=access_key, + address=config.v3io_framesd, + container=container, ) metrics_mapping = {} @@ -555,7 +564,10 @@ def delete_model_endpoints_resources(self, project_name: str): endpoints = self.list_endpoints(auth_info, project_name) for endpoint in endpoints.endpoints: self.delete_endpoint_record( - auth_info, endpoint.metadata.project, endpoint.metadata.uid, access_key, + auth_info, + endpoint.metadata.project, + endpoint.metadata.uid, + access_key, ) v3io = get_v3io_client(endpoint=config.v3io_api, access_key=access_key) @@ -568,7 +580,9 @@ def delete_model_endpoints_resources(self, project_name: str): _, container, path = parse_model_endpoint_store_prefix(path) frames = get_frames_client( - token=access_key, container=container, address=config.v3io_framesd, + token=access_key, + container=container, + address=config.v3io_framesd, ) try: all_records = v3io.kv.new_cursor( @@ -603,7 +617,9 @@ def delete_model_endpoints_resources(self, project_name: str): # Cleanup TSDB try: frames.delete( - backend="tsdb", table=path, if_missing=frames_pb2.IGNORE, + backend="tsdb", + table=path, + if_missing=frames_pb2.IGNORE, ) except CreateError: # frames might raise an exception if schema file does not exist. 
diff --git a/mlrun/api/crud/pipelines.py b/mlrun/api/crud/pipelines.py index ff21ebd92d..cdbc1c2103 100644 --- a/mlrun/api/crud/pipelines.py +++ b/mlrun/api/crud/pipelines.py @@ -18,7 +18,9 @@ from mlrun.utils import logger -class Pipelines(metaclass=mlrun.utils.singleton.Singleton,): +class Pipelines( + metaclass=mlrun.utils.singleton.Singleton, +): def list_pipelines( self, db_session: sqlalchemy.orm.Session, diff --git a/mlrun/api/crud/projects.py b/mlrun/api/crud/projects.py index e7dbf81dfe..96b1dee799 100644 --- a/mlrun/api/crud/projects.py +++ b/mlrun/api/crud/projects.py @@ -112,7 +112,9 @@ def _verify_project_has_no_external_resources(self, project: str): ) def delete_project_resources( - self, session: sqlalchemy.orm.Session, name: str, + self, + session: sqlalchemy.orm.Session, + name: str, ): # Delete schedules before runtime resources - otherwise they will keep getting created mlrun.api.utils.singletons.scheduler.get_scheduler().delete_schedules( @@ -121,7 +123,9 @@ def delete_project_resources( # delete runtime resources mlrun.api.crud.RuntimeResources().delete_runtime_resources( - session, label_selector=f"mlrun/project={name}", force=True, + session, + label_selector=f"mlrun/project={name}", + force=True, ) mlrun.api.crud.Logs().delete_logs(name) @@ -270,7 +274,9 @@ async def _get_project_resources_counters( self._cache["project_resources_counters"]["ttl"] = ttl_time return self._cache["project_resources_counters"]["result"] - async def _calculate_pipelines_counters(self,) -> typing.Dict[str, int]: + async def _calculate_pipelines_counters( + self, + ) -> typing.Dict[str, int]: def _list_pipelines(session): return mlrun.api.crud.Pipelines().list_pipelines( session, "*", format_=mlrun.api.schemas.PipelinesFormat.metadata_only @@ -281,7 +287,8 @@ def _list_pipelines(session): return project_to_running_pipelines_count _, _, pipelines = await fastapi.concurrency.run_in_threadpool( - mlrun.api.db.session.run_function_with_new_db_session, _list_pipelines, + mlrun.api.db.session.run_function_with_new_db_session, + _list_pipelines, ) for pipeline in pipelines: diff --git a/mlrun/api/crud/runs.py b/mlrun/api/crud/runs.py index 26e91c472a..e775bf0ee4 100644 --- a/mlrun/api/crud/runs.py +++ b/mlrun/api/crud/runs.py @@ -15,7 +15,9 @@ from mlrun.utils import logger -class Runs(metaclass=mlrun.utils.singleton.Singleton,): +class Runs( + metaclass=mlrun.utils.singleton.Singleton, +): def store_run( self, db_session: sqlalchemy.orm.Session, @@ -27,7 +29,11 @@ def store_run( project = project or mlrun.mlconf.default_project logger.info("Storing run", data=data) mlrun.api.utils.singletons.db.get_db().store_run( - db_session, data, uid, project, iter=iter, + db_session, + data, + uid, + project, + iter=iter, ) def update_run( diff --git a/mlrun/api/crud/runtime_resources.py b/mlrun/api/crud/runtime_resources.py index 76d7fb2f4e..1aff7fd794 100644 --- a/mlrun/api/crud/runtime_resources.py +++ b/mlrun/api/crud/runtime_resources.py @@ -13,7 +13,9 @@ import mlrun.utils.singleton -class RuntimeResources(metaclass=mlrun.utils.singleton.Singleton,): +class RuntimeResources( + metaclass=mlrun.utils.singleton.Singleton, +): def list_runtime_resources( self, project: str, diff --git a/mlrun/api/crud/secrets.py b/mlrun/api/crud/secrets.py index 380ec9f633..efec87ad3d 100644 --- a/mlrun/api/crud/secrets.py +++ b/mlrun/api/crud/secrets.py @@ -11,7 +11,9 @@ import mlrun.utils.vault -class Secrets(metaclass=mlrun.utils.singleton.Singleton,): +class Secrets( + metaclass=mlrun.utils.singleton.Singleton, +): 
internal_secrets_key_prefix = "mlrun." # make it a subset of internal since key map are by definition internal key_map_secrets_key_prefix = f"{internal_secrets_key_prefix}map." @@ -203,8 +205,10 @@ def list_secrets( raise mlrun.errors.MLRunAccessDeniedError( "Not allowed to list secrets data from kubernetes provider" ) - secrets_data = mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_data( - project, secrets + secrets_data = ( + mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_data( + project, secrets + ) ) else: @@ -386,7 +390,9 @@ def _validate_and_enrich_secrets_to_store( return secrets_to_store def _get_secret_key_map( - self, project: str, key_map_secret_key: str, + self, + project: str, + key_map_secret_key: str, ) -> typing.Optional[dict]: secrets_data = self.list_secrets( project, diff --git a/mlrun/api/db/base.py b/mlrun/api/db/base.py index 1a3b0415dc..083818ffb7 100644 --- a/mlrun/api/db/base.py +++ b/mlrun/api/db/base.py @@ -30,7 +30,12 @@ def initialize(self, session): @abstractmethod def store_log( - self, session, uid, project="", body=None, append=False, + self, + session, + uid, + project="", + body=None, + append=False, ): pass @@ -40,7 +45,12 @@ def get_log(self, session, uid, project="", offset=0, size=0): @abstractmethod def store_run( - self, session, struct, uid, project="", iter=0, + self, + session, + struct, + uid, + project="", + iter=0, ): pass @@ -85,7 +95,14 @@ def del_runs(self, session, name="", project="", labels=None, state="", days_ago @abstractmethod def store_artifact( - self, session, key, artifact, uid, iter=None, tag="", project="", + self, + session, + key, + artifact, + uid, + iter=None, + tag="", + project="", ): pass @@ -129,7 +146,13 @@ def read_metric(self, session, keys, project="", query=""): @abstractmethod def store_function( - self, session, function, name, project="", tag="", versioned=False, + self, + session, + function, + name, + project="", + tag="", + versioned=False, ) -> str: pass @@ -276,7 +299,11 @@ def delete_project( @abstractmethod def create_feature_set( - self, session, project, feature_set: schemas.FeatureSet, versioned=True, + self, + session, + project, + feature_set: schemas.FeatureSet, + versioned=True, ) -> str: pass @@ -343,7 +370,9 @@ def list_feature_sets( @abstractmethod def list_feature_sets_tags( - self, session, project: str, + self, + session, + project: str, ) -> List[Tuple[str, str, str]]: """ :return: a list of Tuple of (project, feature_set.name, tag) @@ -369,7 +398,11 @@ def delete_feature_set(self, session, project, name, tag=None, uid=None): @abstractmethod def create_feature_vector( - self, session, project, feature_vector: schemas.FeatureVector, versioned=True, + self, + session, + project, + feature_vector: schemas.FeatureVector, + versioned=True, ) -> str: pass @@ -397,7 +430,9 @@ def list_feature_vectors( @abstractmethod def list_feature_vectors_tags( - self, session, project: str, + self, + session, + project: str, ) -> List[Tuple[str, str, str]]: """ :return: a list of Tuple of (project, feature_vector.name, tag) @@ -433,7 +468,12 @@ def patch_feature_vector( @abstractmethod def delete_feature_vector( - self, session, project, name, tag=None, uid=None, + self, + session, + project, + name, + tag=None, + uid=None, ): pass diff --git a/mlrun/api/db/filedb/db.py b/mlrun/api/db/filedb/db.py index 502c2760f2..1597650570 100644 --- a/mlrun/api/db/filedb/db.py +++ b/mlrun/api/db/filedb/db.py @@ -14,7 +14,12 @@ def initialize(self, session): self.db.connect() def store_log( - self, 
session, uid, project="", body=None, append=False, + self, + session, + uid, + project="", + body=None, + append=False, ): return self._transform_run_db_error( self.db.store_log, uid, project, body, append @@ -24,7 +29,12 @@ def get_log(self, session, uid, project="", offset=0, size=0): return self._transform_run_db_error(self.db.get_log, uid, project, offset, size) def store_run( - self, session, struct, uid, project="", iter=0, + self, + session, + struct, + uid, + project="", + iter=0, ): return self._transform_run_db_error( self.db.store_run, struct, uid, project, iter @@ -87,7 +97,14 @@ def del_runs(self, session, name="", project="", labels=None, state="", days_ago ) def store_artifact( - self, session, key, artifact, uid, iter=None, tag="", project="", + self, + session, + key, + artifact, + uid, + iter=None, + tag="", + project="", ): return self._transform_run_db_error( self.db.store_artifact, key, artifact, uid, iter, tag, project @@ -125,7 +142,13 @@ def del_artifacts(self, session, name="", project="", tag="", labels=None): ) def store_function( - self, session, function, name, project="", tag="", versioned=False, + self, + session, + function, + name, + project="", + tag="", + versioned=False, ) -> str: return self._transform_run_db_error( self.db.store_function, function, name, project, tag, versioned @@ -215,7 +238,11 @@ def delete_project( raise NotImplementedError() def create_feature_set( - self, session, project, feature_set: schemas.FeatureSet, versioned=True, + self, + session, + project, + feature_set: schemas.FeatureSet, + versioned=True, ) -> str: raise NotImplementedError() @@ -276,7 +303,9 @@ def list_feature_sets( raise NotImplementedError() def list_feature_sets_tags( - self, session, project: str, + self, + session, + project: str, ): raise NotImplementedError() @@ -296,7 +325,11 @@ def delete_feature_set(self, session, project, name, tag=None, uid=None): raise NotImplementedError() def create_feature_vector( - self, session, project, feature_vector: schemas.FeatureVector, versioned=True, + self, + session, + project, + feature_vector: schemas.FeatureVector, + versioned=True, ) -> str: raise NotImplementedError() @@ -321,7 +354,9 @@ def list_feature_vectors( raise NotImplementedError() def list_feature_vectors_tags( - self, session, project: str, + self, + session, + project: str, ): raise NotImplementedError() diff --git a/mlrun/api/db/sqldb/db.py b/mlrun/api/db/sqldb/db.py index 1efbd17b17..2d089a272f 100644 --- a/mlrun/api/db/sqldb/db.py +++ b/mlrun/api/db/sqldb/db.py @@ -74,7 +74,12 @@ def initialize(self, session): pass def store_log( - self, session, uid, project="", body=b"", append=False, + self, + session, + uid, + project="", + body=b"", + append=False, ): raise NotImplementedError("DB should not be used for logs storage") @@ -94,7 +99,12 @@ def _list_logs(self, session: Session, project: str): return self._query(session, Log, project=project).all() def store_run( - self, session, run_data, uid, project="", iter=0, + self, + session, + run_data, + uid, + project="", + iter=0, ): logger.debug( "Storing run to db", project=project, uid=uid, iter=iter, run=run_data @@ -271,10 +281,23 @@ def _update_run_state(run_record: Run, run_dict: dict): run_dict.setdefault("status", {})["state"] = state def store_artifact( - self, session, key, artifact, uid, iter=None, tag="", project="", + self, + session, + key, + artifact, + uid, + iter=None, + tag="", + project="", ): self._store_artifact( - session, key, artifact, uid, iter, tag, project, + session, + key, + 
artifact, + uid, + iter, + tag, + project, ) def _store_artifact( @@ -493,7 +516,13 @@ def del_artifacts(self, session, name="", project="", tag="*", labels=None): self.del_artifact(session, key, "", project) def store_function( - self, session, function, name, project="", tag="", versioned=False, + self, + session, + function, + name, + project="", + tag="", + versioned=False, ) -> str: logger.debug( "Storing function to DB", @@ -533,7 +562,11 @@ def store_function( function.setdefault("metadata", {})["name"] = name fn = self._get_class_instance_by_uid(session, Function, name, project, uid) if not fn: - fn = Function(name=name, project=project, uid=uid,) + fn = Function( + name=name, + project=project, + uid=uid, + ) fn.updated = updated labels = get_in(function, "metadata.labels", {}) update_labels(fn, labels) @@ -843,7 +876,12 @@ def _list_project_feature_vector_names( def tag_artifacts(self, session, artifacts, project: str, name: str): for artifact in artifacts: query = ( - self._query(session, artifact.Tag, project=project, name=name,) + self._query( + session, + artifact.Tag, + project=project, + name=name, + ) .join(Artifact) .filter(Artifact.key == artifact.key) ) @@ -995,7 +1033,10 @@ async def get_project_resources_counters( project_to_schedule_count, project_to_feature_set_count, project_to_models_count, - (project_to_recent_failed_runs_count, project_to_running_runs_count,), + ( + project_to_recent_failed_runs_count, + project_to_running_runs_count, + ), ) = results return ( project_to_files_count, @@ -1165,7 +1206,9 @@ def _patch_project_record_from_project( # If a bad kind value was passed, it will fail here (return 422 to caller) project = schemas.Project(**project_record_full_object) self.store_project( - session, name, project, + session, + name, + project, ) project_record.full_object = project_record_full_object @@ -1252,7 +1295,13 @@ def _verify_empty_list_of_project_related_resources( ) def _get_record_by_name_tag_and_uid( - self, session, cls, project: str, name: str, tag: str = None, uid: str = None, + self, + session, + cls, + project: str, + name: str, + tag: str = None, + uid: str = None, ): query = self._query(session, cls, name=name, project=project) computed_tag = tag or "latest" @@ -1269,7 +1318,12 @@ def _get_record_by_name_tag_and_uid( return computed_tag, object_tag_uid, query.one_or_none() def _get_feature_set( - self, session, project: str, name: str, tag: str = None, uid: str = None, + self, + session, + project: str, + name: str, + tag: str = None, + uid: str = None, ): ( computed_tag, @@ -1289,7 +1343,12 @@ def _get_feature_set( return None def get_feature_set( - self, session, project: str, name: str, tag: str = None, uid: str = None, + self, + session, + project: str, + name: str, + tag: str = None, + uid: str = None, ) -> schemas.FeatureSet: feature_set = self._get_feature_set(session, project, name, tag, uid) if not feature_set: @@ -1342,7 +1401,8 @@ def _generate_feature_set_digest(feature_set: schemas.FeatureSet): return schemas.FeatureSetDigestOutput( metadata=feature_set.metadata, spec=schemas.FeatureSetDigestSpec( - entities=feature_set.spec.entities, features=feature_set.spec.features, + entities=feature_set.spec.entities, + features=feature_set.spec.features, ), ) @@ -1595,7 +1655,9 @@ def list_feature_sets( return schemas.FeatureSetsOutput(feature_sets=feature_sets) def list_feature_sets_tags( - self, session, project: str, + self, + session, + project: str, ): query = ( session.query(FeatureSet.name, FeatureSet.Tag.name) @@ -1651,7 
+1713,10 @@ def _update_feature_set_spec( @staticmethod def _common_object_validate_and_perform_uid_change( - object_dict: dict, tag, versioned, existing_uid=None, + object_dict: dict, + tag, + versioned, + existing_uid=None, ): uid = fill_object_hash(object_dict, "uid", tag) if not versioned: @@ -1667,7 +1732,9 @@ def _common_object_validate_and_perform_uid_change( @staticmethod def _update_db_record_from_object_dict( - db_object, common_object_dict: dict, uid, + db_object, + common_object_dict: dict, + uid, ): db_object.name = common_object_dict["metadata"]["name"] updated_datetime = datetime.now(timezone.utc) @@ -1742,7 +1809,12 @@ def store_feature_set( return uid def _validate_and_enrich_record_for_creation( - self, session, new_object, db_class, project, versioned, + self, + session, + new_object, + db_class, + project, + versioned, ): object_type = new_object.__class__.__name__ @@ -1769,7 +1841,11 @@ def _validate_and_enrich_record_for_creation( return uid, new_object.metadata.tag, object_dict def create_feature_set( - self, session, project, feature_set: schemas.FeatureSet, versioned=True, + self, + session, + project, + feature_set: schemas.FeatureSet, + versioned=True, ) -> str: (uid, tag, feature_set_dict,) = self._validate_and_enrich_record_for_creation( session, feature_set, FeatureSet, project, versioned @@ -1858,7 +1934,11 @@ def delete_feature_set(self, session, project, name, tag=None, uid=None): self._delete_feature_store_object(session, FeatureSet, project, name, tag, uid) def create_feature_vector( - self, session, project, feature_vector: schemas.FeatureVector, versioned=True, + self, + session, + project, + feature_vector: schemas.FeatureVector, + versioned=True, ) -> str: ( uid, @@ -1880,7 +1960,12 @@ def create_feature_vector( return uid def _get_feature_vector( - self, session, project: str, name: str, tag: str = None, uid: str = None, + self, + session, + project: str, + name: str, + tag: str = None, + uid: str = None, ): ( computed_tag, @@ -1969,7 +2054,9 @@ def list_feature_vectors( return schemas.FeatureVectorsOutput(feature_vectors=feature_vectors) def list_feature_vectors_tags( - self, session, project: str, + self, + session, + project: str, ): query = ( session.query(FeatureVector.name, FeatureVector.Tag.name) @@ -2212,7 +2299,10 @@ def _latest_uid_filter(self, session, query): Artifact.key, func.max(Artifact.updated), ) - .group_by(Artifact.project, Artifact.key.label("key"),) + .group_by( + Artifact.project, + Artifact.key.label("key"), + ) .subquery("max_key") ) @@ -2465,7 +2555,8 @@ def _delete_class_labels( session.commit() def _transform_schedule_record_to_scheme( - self, schedule_record: Schedule, + self, + schedule_record: Schedule, ) -> schemas.ScheduleRecord: schedule = schemas.ScheduleRecord.from_orm(schedule_record) schedule.creation_time = self._add_utc_timezone(schedule.creation_time) @@ -2483,7 +2574,8 @@ def _add_utc_timezone(time_value: typing.Optional[datetime]): @staticmethod def _transform_feature_set_model_to_schema( - feature_set_record: FeatureSet, tag=None, + feature_set_record: FeatureSet, + tag=None, ) -> schemas.FeatureSet: feature_set_full_dict = feature_set_record.full_object feature_set_resp = schemas.FeatureSet(**feature_set_full_dict) @@ -2493,7 +2585,8 @@ def _transform_feature_set_model_to_schema( @staticmethod def _transform_feature_vector_model_to_schema( - feature_vector_record: FeatureVector, tag=None, + feature_vector_record: FeatureVector, + tag=None, ) -> schemas.FeatureVector: feature_vector_full_dict = 
feature_vector_record.full_object feature_vector_resp = schemas.FeatureVector(**feature_vector_full_dict) @@ -2509,13 +2602,16 @@ def _transform_project_record_to_schema( if not project_record.full_object: project = schemas.Project( metadata=schemas.ProjectMetadata( - name=project_record.name, created=project_record.created, + name=project_record.name, + created=project_record.created, ), spec=schemas.ProjectSpec( description=project_record.description, source=project_record.source, ), - status=schemas.ObjectStatus(state=project_record.state,), + status=schemas.ObjectStatus( + state=project_record.state, + ), ) self.store_project(session, project_record.name, project) return project @@ -2652,7 +2748,10 @@ def create_marketplace_source( ) def store_marketplace_source( - self, session, name, ordered_source: schemas.IndexedMarketplaceSource, + self, + session, + name, + ordered_source: schemas.IndexedMarketplaceSource, ): logger.debug( "Storing marketplace source in DB", index=ordered_source.index, name=name @@ -2744,7 +2843,8 @@ def get_current_data_version( def create_data_version(self, session, version): logger.debug( - "Creating data version in DB", version=version, + "Creating data version in DB", + version=version, ) now = datetime.now(timezone.utc) diff --git a/mlrun/api/initial_data.py b/mlrun/api/initial_data.py index 56de885081..2ed011a7a8 100644 --- a/mlrun/api/initial_data.py +++ b/mlrun/api/initial_data.py @@ -198,7 +198,8 @@ def _add_initial_data(db_session: sqlalchemy.orm.Session): def _fix_datasets_large_previews( - db: mlrun.api.db.sqldb.db.SQLDB, db_session: sqlalchemy.orm.Session, + db: mlrun.api.db.sqldb.db.SQLDB, + db_session: sqlalchemy.orm.Session, ): logger.info("Fixing datasets large previews") # get all artifacts @@ -273,7 +274,8 @@ def _fix_datasets_large_previews( ) except Exception as exc: logger.warning( - "Failed fixing dataset artifact large preview. Continuing", exc=exc, + "Failed fixing dataset artifact large preview. 
Continuing", + exc=exc, ) @@ -472,7 +474,8 @@ def _add_data_version( if db.get_current_data_version(db_session, raise_on_not_found=False) is None: data_version = _resolve_current_data_version(db, db_session) logger.info( - "No data version, setting data version", data_version=data_version, + "No data version, setting data version", + data_version=data_version, ) db.create_data_version(db_session, data_version) diff --git a/mlrun/api/migrations_sqlite/versions/11f8dd2dc9fe_init.py b/mlrun/api/migrations_sqlite/versions/11f8dd2dc9fe_init.py index 9017b0f849..e361737f58 100644 --- a/mlrun/api/migrations_sqlite/versions/11f8dd2dc9fe_init.py +++ b/mlrun/api/migrations_sqlite/versions/11f8dd2dc9fe_init.py @@ -189,7 +189,10 @@ def upgrade(): nullable=True, ), sa.Column("parent", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["parent"], ["artifacts.id"],), + sa.ForeignKeyConstraint( + ["parent"], + ["artifacts.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name", "parent", name="_artifacts_labels_uc"), ) @@ -207,7 +210,10 @@ def upgrade(): nullable=True, ), sa.Column("obj_id", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["obj_id"], ["artifacts.id"],), + sa.ForeignKeyConstraint( + ["obj_id"], + ["artifacts.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("project", "name", "obj_id", name="_artifacts_tags_uc"), ) @@ -225,7 +231,10 @@ def upgrade(): nullable=True, ), sa.Column("parent", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["parent"], ["functions.id"],), + sa.ForeignKeyConstraint( + ["parent"], + ["functions.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name", "parent", name="_functions_labels_uc"), ) @@ -248,8 +257,14 @@ def upgrade(): sa.String(255, collation=SQLCollationUtil.collation()), nullable=True, ), - sa.ForeignKeyConstraint(["obj_id"], ["functions.id"],), - sa.ForeignKeyConstraint(["obj_name"], ["functions.name"],), + sa.ForeignKeyConstraint( + ["obj_id"], + ["functions.id"], + ), + sa.ForeignKeyConstraint( + ["obj_name"], + ["functions.name"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("project", "name", "obj_name", name="_functions_tags_uc"), ) @@ -257,8 +272,14 @@ def upgrade(): "project_users", sa.Column("project_id", sa.Integer(), nullable=True), sa.Column("user_id", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["project_id"], ["projects.id"],), - sa.ForeignKeyConstraint(["user_id"], ["users.id"],), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), ) op.create_table( "runs_labels", @@ -274,7 +295,10 @@ def upgrade(): nullable=True, ), sa.Column("parent", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["parent"], ["runs.id"],), + sa.ForeignKeyConstraint( + ["parent"], + ["runs.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name", "parent", name="_runs_labels_uc"), ) @@ -292,7 +316,10 @@ def upgrade(): nullable=True, ), sa.Column("obj_id", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["obj_id"], ["runs.id"],), + sa.ForeignKeyConstraint( + ["obj_id"], + ["runs.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("project", "name", "obj_id", name="_runs_tags_uc"), ) diff --git a/mlrun/api/migrations_sqlite/versions/2b6d23c715aa_adding_feature_sets.py b/mlrun/api/migrations_sqlite/versions/2b6d23c715aa_adding_feature_sets.py index eaf82ae615..5860deb4cb 100644 --- a/mlrun/api/migrations_sqlite/versions/2b6d23c715aa_adding_feature_sets.py +++ 
b/mlrun/api/migrations_sqlite/versions/2b6d23c715aa_adding_feature_sets.py @@ -60,7 +60,10 @@ def upgrade(): sa.String(255, collation=SQLCollationUtil.collation()), nullable=True, ), - sa.ForeignKeyConstraint(["feature_set_id"], ["feature_sets.id"],), + sa.ForeignKeyConstraint( + ["feature_set_id"], + ["feature_sets.id"], + ), sa.PrimaryKeyConstraint("id"), ) op.create_table( @@ -77,7 +80,10 @@ def upgrade(): nullable=True, ), sa.Column("parent", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["parent"], ["feature_sets.id"],), + sa.ForeignKeyConstraint( + ["parent"], + ["feature_sets.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name", "parent", name="_feature_sets_labels_uc"), ) @@ -100,8 +106,14 @@ def upgrade(): sa.String(255, collation=SQLCollationUtil.collation()), nullable=True, ), - sa.ForeignKeyConstraint(["obj_id"], ["feature_sets.id"],), - sa.ForeignKeyConstraint(["obj_name"], ["feature_sets.name"],), + sa.ForeignKeyConstraint( + ["obj_id"], + ["feature_sets.id"], + ), + sa.ForeignKeyConstraint( + ["obj_name"], + ["feature_sets.name"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint( "project", "name", "obj_name", name="_feature_sets_tags_uc" @@ -121,7 +133,10 @@ def upgrade(): sa.String(255, collation=SQLCollationUtil.collation()), nullable=True, ), - sa.ForeignKeyConstraint(["feature_set_id"], ["feature_sets.id"],), + sa.ForeignKeyConstraint( + ["feature_set_id"], + ["feature_sets.id"], + ), sa.PrimaryKeyConstraint("id"), ) # ### end Alembic commands ### diff --git a/mlrun/api/migrations_sqlite/versions/b68e8e897a28_schedule_labels.py b/mlrun/api/migrations_sqlite/versions/b68e8e897a28_schedule_labels.py index 8110f1a8a3..d220c8783d 100644 --- a/mlrun/api/migrations_sqlite/versions/b68e8e897a28_schedule_labels.py +++ b/mlrun/api/migrations_sqlite/versions/b68e8e897a28_schedule_labels.py @@ -33,7 +33,10 @@ def upgrade(): nullable=True, ), sa.Column("parent", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["parent"], ["schedules_v2.id"],), + sa.ForeignKeyConstraint( + ["parent"], + ["schedules_v2.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name", "parent", name="_schedules_v2_labels_uc"), ) diff --git a/mlrun/api/migrations_sqlite/versions/bcd0c1f9720c_adding_project_labels.py b/mlrun/api/migrations_sqlite/versions/bcd0c1f9720c_adding_project_labels.py index a89d1ed405..802ba958df 100644 --- a/mlrun/api/migrations_sqlite/versions/bcd0c1f9720c_adding_project_labels.py +++ b/mlrun/api/migrations_sqlite/versions/bcd0c1f9720c_adding_project_labels.py @@ -33,7 +33,10 @@ def upgrade(): nullable=True, ), sa.Column("parent", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["parent"], ["projects.id"],), + sa.ForeignKeyConstraint( + ["parent"], + ["projects.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name", "parent", name="_projects_labels_uc"), ) diff --git a/mlrun/api/migrations_sqlite/versions/f4249b4ba6fa_adding_feature_vectors.py b/mlrun/api/migrations_sqlite/versions/f4249b4ba6fa_adding_feature_vectors.py index 43bb57aa24..f70cf85a19 100644 --- a/mlrun/api/migrations_sqlite/versions/f4249b4ba6fa_adding_feature_vectors.py +++ b/mlrun/api/migrations_sqlite/versions/f4249b4ba6fa_adding_feature_vectors.py @@ -60,7 +60,10 @@ def upgrade(): nullable=True, ), sa.Column("parent", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["parent"], ["feature_vectors.id"],), + sa.ForeignKeyConstraint( + ["parent"], + ["feature_vectors.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name", 
"parent", name="_feature_vectors_labels_uc"), ) @@ -83,8 +86,14 @@ def upgrade(): sa.String(255, collation=SQLCollationUtil.collation()), nullable=True, ), - sa.ForeignKeyConstraint(["obj_id"], ["feature_vectors.id"],), - sa.ForeignKeyConstraint(["obj_name"], ["feature_vectors.name"],), + sa.ForeignKeyConstraint( + ["obj_id"], + ["feature_vectors.id"], + ), + sa.ForeignKeyConstraint( + ["obj_name"], + ["feature_vectors.name"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint( "project", "name", "obj_name", name="_feature_vectors_tags_uc" diff --git a/mlrun/api/migrations_sqlite/versions/f7b5a1a03629_adding_feature_labels.py b/mlrun/api/migrations_sqlite/versions/f7b5a1a03629_adding_feature_labels.py index 3c5e5d89be..d2888e8278 100644 --- a/mlrun/api/migrations_sqlite/versions/f7b5a1a03629_adding_feature_labels.py +++ b/mlrun/api/migrations_sqlite/versions/f7b5a1a03629_adding_feature_labels.py @@ -33,7 +33,10 @@ def upgrade(): nullable=True, ), sa.Column("parent", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["parent"], ["entities.id"],), + sa.ForeignKeyConstraint( + ["parent"], + ["entities.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name", "parent", name="_entities_labels_uc"), ) @@ -51,7 +54,10 @@ def upgrade(): nullable=True, ), sa.Column("parent", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["parent"], ["features.id"],), + sa.ForeignKeyConstraint( + ["parent"], + ["features.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name", "parent", name="_features_labels_uc"), ) diff --git a/mlrun/api/schemas/auth.py b/mlrun/api/schemas/auth.py index e34178e0b0..e864b01845 100644 --- a/mlrun/api/schemas/auth.py +++ b/mlrun/api/schemas/auth.py @@ -44,7 +44,9 @@ class AuthorizationResourceTypes(str, enum.Enum): marketplace_source = "marketplace-source" def to_resource_string( - self, project_name: str, resource_name: str, + self, + project_name: str, + resource_name: str, ): return { # project is the resource itself, so no need for both resource_name and project_name diff --git a/mlrun/api/schemas/model_endpoints.py b/mlrun/api/schemas/model_endpoints.py index dd8813b3c5..b8ff621385 100644 --- a/mlrun/api/schemas/model_endpoints.py +++ b/mlrun/api/schemas/model_endpoints.py @@ -116,7 +116,8 @@ def __init__(self, **data: Any): super().__init__(**data) if self.metadata.uid is None: uid = create_model_endpoint_id( - function_uri=self.spec.function_uri, versioned_model=self.spec.model, + function_uri=self.spec.function_uri, + versioned_model=self.spec.model, ) self.metadata.uid = str(uid) diff --git a/mlrun/api/utils/auth/verifier.py b/mlrun/api/utils/auth/verifier.py index e38bbf90e5..f4f6d3d74a 100644 --- a/mlrun/api/utils/auth/verifier.py +++ b/mlrun/api/utils/auth/verifier.py @@ -119,7 +119,11 @@ def query_global_resource_permissions( raise_on_forbidden: bool = True, ) -> bool: return self.query_resource_permissions( - resource_type, "", action, auth_info, raise_on_forbidden, + resource_type, + "", + action, + auth_info, + raise_on_forbidden, ) def query_resource_permissions( @@ -145,7 +149,10 @@ def query_permissions( raise_on_forbidden: bool = True, ) -> bool: return self._auth_provider.query_permissions( - resource, action, auth_info, raise_on_forbidden, + resource, + action, + auth_info, + raise_on_forbidden, ) def filter_by_permissions( @@ -156,7 +163,10 @@ def filter_by_permissions( auth_info: mlrun.api.schemas.AuthInfo, ) -> typing.List: return self._auth_provider.filter_by_permissions( - resources, opa_resource_extractor, 
action, auth_info, + resources, + opa_resource_extractor, + action, + auth_info, ) def add_allowed_project_for_owner( diff --git a/mlrun/api/utils/background_tasks.py b/mlrun/api/utils/background_tasks.py index bb3d008b13..c58fdb8db9 100644 --- a/mlrun/api/utils/background_tasks.py +++ b/mlrun/api/utils/background_tasks.py @@ -41,7 +41,11 @@ def create_project_background_task( return self.get_project_background_task(project, name) def create_background_task( - self, background_tasks: fastapi.BackgroundTasks, function, *args, **kwargs, + self, + background_tasks: fastapi.BackgroundTasks, + function, + *args, + **kwargs, ) -> mlrun.api.schemas.BackgroundTask: name = str(uuid.uuid4()) # sanity @@ -69,7 +73,9 @@ def _generate_background_task( ) def get_project_background_task( - self, project: str, name: str, + self, + project: str, + name: str, ) -> mlrun.api.schemas.BackgroundTask: if ( project in self._project_background_tasks @@ -79,7 +85,10 @@ def get_project_background_task( else: return self._generate_background_task_not_found_response(name, project) - def get_background_task(self, name: str,) -> mlrun.api.schemas.BackgroundTask: + def get_background_task( + self, + name: str, + ) -> mlrun.api.schemas.BackgroundTask: if name in self._background_tasks: return self._background_tasks[name] else: diff --git a/mlrun/api/utils/clients/iguazio.py b/mlrun/api/utils/clients/iguazio.py index 42290972b6..7f57d1f938 100644 --- a/mlrun/api/utils/clients/iguazio.py +++ b/mlrun/api/utils/clients/iguazio.py @@ -165,7 +165,10 @@ def create_project( return self._create_project_in_iguazio(session, body, wait_for_completion) def update_project( - self, session: str, name: str, project: mlrun.api.schemas.Project, + self, + session: str, + name: str, + project: mlrun.api.schemas.Project, ): logger.debug("Updating project in Iguazio", name=name, project=project) body = self._transform_mlrun_project_to_iguazio_project(project) @@ -254,11 +257,17 @@ def list_projects( latest_updated_at = self._find_latest_updated_at(response_body) return projects, latest_updated_at - def get_project(self, session: str, name: str,) -> mlrun.api.schemas.Project: + def get_project( + self, + session: str, + name: str, + ) -> mlrun.api.schemas.Project: return self._get_project_from_iguazio(session, name) def get_project_owner( - self, session: str, name: str, + self, + session: str, + name: str, ) -> mlrun.api.schemas.ProjectOwner: response = self._get_project_from_iguazio_without_parsing( session, name, include_owner_session=True @@ -541,12 +550,16 @@ def _transform_iguazio_project_to_mlrun_project( "description" ] if iguazio_project["attributes"].get("labels"): - mlrun_project.metadata.labels = Client._transform_iguazio_labels_to_mlrun_labels( - iguazio_project["attributes"]["labels"] + mlrun_project.metadata.labels = ( + Client._transform_iguazio_labels_to_mlrun_labels( + iguazio_project["attributes"]["labels"] + ) ) if iguazio_project["attributes"].get("annotations"): - mlrun_project.metadata.annotations = Client._transform_iguazio_labels_to_mlrun_labels( - iguazio_project["attributes"]["annotations"] + mlrun_project.metadata.annotations = ( + Client._transform_iguazio_labels_to_mlrun_labels( + iguazio_project["attributes"]["annotations"] + ) ) if iguazio_project["attributes"].get("owner_username"): mlrun_project.spec.owner = iguazio_project["attributes"]["owner_username"] diff --git a/mlrun/api/utils/db/backup.py b/mlrun/api/utils/db/backup.py index bc4fb1bb06..1a808e75d4 100644 --- a/mlrun/api/utils/db/backup.py +++ 
b/mlrun/api/utils/db/backup.py @@ -98,7 +98,8 @@ def _load_database_backup_mysql(self, backup_file_name: str) -> None: backup_path = self._get_backup_file_path(backup_file_name) logger.debug( - "Loading mysql DB backup data", backup_path=backup_path, + "Loading mysql DB backup data", + backup_path=backup_path, ) dsn_data = mlrun.api.utils.db.mysql.MySQLUtil.get_mysql_dsn_data() self._run_shell_command( @@ -175,7 +176,8 @@ def _get_sqlite_db_file_path() -> str: @staticmethod def _run_shell_command(command: str) -> int: logger.debug( - "Running shell command", command=command, + "Running shell command", + command=command, ) process = subprocess.Popen( command, diff --git a/mlrun/api/utils/projects/follower.py b/mlrun/api/utils/projects/follower.py index 54bf738f5e..00913efa27 100644 --- a/mlrun/api/utils/projects/follower.py +++ b/mlrun/api/utils/projects/follower.py @@ -104,7 +104,8 @@ def ensure_project( ) if is_project_created: mlrun.api.utils.auth.verifier.AuthVerifier().add_allowed_project_for_owner( - name, auth_info, + name, + auth_info, ) return is_project_created @@ -209,7 +210,10 @@ def delete_project( ) else: return self._leader_client.delete_project( - auth_info.session, name, deletion_strategy, wait_for_completion, + auth_info.session, + name, + deletion_strategy, + wait_for_completion, ) return False @@ -224,7 +228,9 @@ def get_project( return self._projects[name] def get_project_owner( - self, db_session: sqlalchemy.orm.Session, name: str, + self, + db_session: sqlalchemy.orm.Session, + name: str, ) -> mlrun.api.schemas.ProjectOwner: return self._leader_client.get_project_owner(self._sync_session, name) diff --git a/mlrun/api/utils/projects/leader.py b/mlrun/api/utils/projects/leader.py index df481afa86..042f583d92 100644 --- a/mlrun/api/utils/projects/leader.py +++ b/mlrun/api/utils/projects/leader.py @@ -147,7 +147,9 @@ async def get_project_summary( return await self._leader_follower.get_project_summary(db_session, name) def get_project_owner( - self, db_session: sqlalchemy.orm.Session, name: str, + self, + db_session: sqlalchemy.orm.Session, + name: str, ) -> mlrun.api.schemas.ProjectOwner: raise NotImplementedError() @@ -288,7 +290,8 @@ def _ensure_project_synced( try: self._enrich_and_validate_before_creation(project) self._followers[missing_follower].create_project( - db_session, project, + db_session, + project, ) except Exception as exc: logger.warning( diff --git a/mlrun/api/utils/projects/member.py b/mlrun/api/utils/projects/member.py index 9f3096a988..9906efea96 100644 --- a/mlrun/api/utils/projects/member.py +++ b/mlrun/api/utils/projects/member.py @@ -141,6 +141,8 @@ async def list_project_summaries( @abc.abstractmethod def get_project_owner( - self, db_session: sqlalchemy.orm.Session, name: str, + self, + db_session: sqlalchemy.orm.Session, + name: str, ) -> mlrun.api.schemas.ProjectOwner: pass diff --git a/mlrun/api/utils/projects/remotes/leader.py b/mlrun/api/utils/projects/remotes/leader.py index 8be4a90d5d..caa9639a6b 100644 --- a/mlrun/api/utils/projects/remotes/leader.py +++ b/mlrun/api/utils/projects/remotes/leader.py @@ -17,7 +17,10 @@ def create_project( @abc.abstractmethod def update_project( - self, session: str, name: str, project: mlrun.api.schemas.Project, + self, + session: str, + name: str, + project: mlrun.api.schemas.Project, ): pass @@ -33,14 +36,20 @@ def delete_project( @abc.abstractmethod def list_projects( - self, session: str, updated_after: typing.Optional[datetime.datetime] = None, + self, + session: str, + updated_after: 
typing.Optional[datetime.datetime] = None, ) -> typing.Tuple[ typing.List[mlrun.api.schemas.Project], typing.Optional[datetime.datetime] ]: pass @abc.abstractmethod - def get_project(self, session: str, name: str,) -> mlrun.api.schemas.Project: + def get_project( + self, + session: str, + name: str, + ) -> mlrun.api.schemas.Project: pass @abc.abstractmethod @@ -51,6 +60,8 @@ def format_as_leader_project( @abc.abstractmethod def get_project_owner( - self, session: str, name: str, + self, + session: str, + name: str, ) -> mlrun.api.schemas.ProjectOwner: pass diff --git a/mlrun/api/utils/projects/remotes/nop_leader.py b/mlrun/api/utils/projects/remotes/nop_leader.py index 0c81eaa187..89cc7b50cd 100644 --- a/mlrun/api/utils/projects/remotes/nop_leader.py +++ b/mlrun/api/utils/projects/remotes/nop_leader.py @@ -30,7 +30,10 @@ def create_project( return is_running_in_background def update_project( - self, session: str, name: str, project: mlrun.api.schemas.Project, + self, + session: str, + name: str, + project: mlrun.api.schemas.Project, ): self._update_state(project) mlrun.api.utils.singletons.project_member.get_project_member().store_project( @@ -55,7 +58,9 @@ def delete_project( ) def list_projects( - self, session: str, updated_after: typing.Optional[datetime.datetime] = None, + self, + session: str, + updated_after: typing.Optional[datetime.datetime] = None, ) -> typing.Tuple[ typing.List[mlrun.api.schemas.Project], typing.Optional[datetime.datetime] ]: @@ -66,9 +71,15 @@ def list_projects( datetime.datetime.utcnow(), ) - def get_project(self, session: str, name: str,) -> mlrun.api.schemas.Project: - return mlrun.api.utils.singletons.project_member.get_project_member().get_project( - self.db_session, name + def get_project( + self, + session: str, + name: str, + ) -> mlrun.api.schemas.Project: + return ( + mlrun.api.utils.singletons.project_member.get_project_member().get_project( + self.db_session, name + ) ) def format_as_leader_project( @@ -77,7 +88,9 @@ def format_as_leader_project( return mlrun.api.schemas.IguazioProject(data=project.dict()) def get_project_owner( - self, session: str, name: str, + self, + session: str, + name: str, ) -> mlrun.api.schemas.ProjectOwner: project = self.get_project(session, name) return mlrun.api.schemas.ProjectOwner( diff --git a/mlrun/api/utils/scheduler.py b/mlrun/api/utils/scheduler.py index 7a59e3155c..83457160a8 100644 --- a/mlrun/api/utils/scheduler.py +++ b/mlrun/api/utils/scheduler.py @@ -194,16 +194,24 @@ def get_schedule( ) def delete_schedule( - self, db_session: Session, project: str, name: str, + self, + db_session: Session, + project: str, + name: str, ): logger.debug("Deleting schedule", project=project, name=name) self._remove_schedule_scheduler_resources(project, name) get_db().delete_schedule(db_session, project, name) def delete_schedules( - self, db_session: Session, project: str, + self, + db_session: Session, + project: str, ): - schedules = self.list_schedules(db_session, project,) + schedules = self.list_schedules( + db_session, + project, + ) logger.debug("Deleting schedules", project=project) for schedule in schedules.schedules: self._remove_schedule_scheduler_resources(schedule.project, schedule.name) @@ -242,7 +250,9 @@ async def invoke_schedule( return await function(*args, **kwargs) def _ensure_auth_info_has_access_key( - self, auth_info: mlrun.api.schemas.AuthInfo, kind: schemas.ScheduleKinds, + self, + auth_info: mlrun.api.schemas.AuthInfo, + kind: schemas.ScheduleKinds, ): if ( kind not in 
schemas.ScheduleKinds.local_kinds() @@ -252,12 +262,17 @@ def _ensure_auth_info_has_access_key( or auth_info.access_key == mlrun.model.Credentials.generate_access_key ) ): - auth_info.access_key = mlrun.api.utils.auth.verifier.AuthVerifier().get_or_create_access_key( - auth_info.session + auth_info.access_key = ( + mlrun.api.utils.auth.verifier.AuthVerifier().get_or_create_access_key( + auth_info.session + ) ) def _store_schedule_secrets( - self, auth_info: mlrun.api.schemas.AuthInfo, project: str, name: str, + self, + auth_info: mlrun.api.schemas.AuthInfo, + project: str, + name: str, ): # import here to avoid circular imports import mlrun.api.crud @@ -268,8 +283,8 @@ def _store_schedule_secrets( raise mlrun.errors.MLRunAccessDeniedError( "Access key is required to create schedules in OPA authorization mode" ) - access_key_secret_key = mlrun.api.crud.Secrets().generate_schedule_access_key_secret_key( - name + access_key_secret_key = ( + mlrun.api.crud.Secrets().generate_schedule_access_key_secret_key(name) ) # schedule name may be an invalid secret key, therefore we're using the key map feature of our secrets # handler @@ -280,29 +295,34 @@ def _store_schedule_secrets( access_key_secret_key: auth_info.access_key, } if auth_info.username: - username_secret_key = mlrun.api.crud.Secrets().generate_schedule_username_secret_key( - name + username_secret_key = ( + mlrun.api.crud.Secrets().generate_schedule_username_secret_key(name) ) secrets[username_secret_key] = auth_info.username mlrun.api.crud.Secrets().store_secrets( project, - schemas.SecretsData(provider=self._secrets_provider, secrets=secrets,), + schemas.SecretsData( + provider=self._secrets_provider, + secrets=secrets, + ), allow_internal_secrets=True, key_map_secret_key=secret_key_map, ) def _remove_schedule_secrets( - self, project: str, name: str, + self, + project: str, + name: str, ): # import here to avoid circular imports import mlrun.api.crud if mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required(): - access_key_secret_key = mlrun.api.crud.Secrets().generate_schedule_access_key_secret_key( - name + access_key_secret_key = ( + mlrun.api.crud.Secrets().generate_schedule_access_key_secret_key(name) ) - username_secret_key = mlrun.api.crud.Secrets().generate_schedule_username_secret_key( - name + username_secret_key = ( + mlrun.api.crud.Secrets().generate_schedule_username_secret_key(name) ) secret_key_map = ( mlrun.api.crud.Secrets().generate_schedule_key_map_secret_key() @@ -331,8 +351,8 @@ def _get_schedule_secrets( # import here to avoid circular imports import mlrun.api.crud - schedule_access_key_secret_key = mlrun.api.crud.Secrets().generate_schedule_access_key_secret_key( - name + schedule_access_key_secret_key = ( + mlrun.api.crud.Secrets().generate_schedule_access_key_secret_key(name) ) secret_key_map = mlrun.api.crud.Secrets().generate_schedule_key_map_secret_key() # TODO: support listing (and not only get) secrets using key map @@ -346,8 +366,8 @@ def _get_schedule_secrets( ) username = None if include_username: - schedule_username_secret_key = mlrun.api.crud.Secrets().generate_schedule_username_secret_key( - name + schedule_username_secret_key = ( + mlrun.api.crud.Secrets().generate_schedule_username_secret_key(name) ) username = mlrun.api.crud.Secrets().get_secret( project, @@ -370,8 +390,10 @@ def _validate_cron_trigger( Enforce no more then one job per min_allowed_interval """ logger.debug("Validating cron trigger") - apscheduler_cron_trigger = 
self.transform_schemas_cron_trigger_to_apscheduler_cron_trigger( - cron_trigger + apscheduler_cron_trigger = ( + self.transform_schemas_cron_trigger_to_apscheduler_cron_trigger( + cron_trigger + ) ) now = now or datetime.now(apscheduler_cron_trigger.timezone) next_run_time = None @@ -675,7 +697,10 @@ async def submit_run_wrapper( run_metadata["project"], run_metadata["uid"], run_metadata["iteration"] ) get_db().update_schedule( - db_session, run_metadata["project"], schedule_name, last_run_uri=run_uri, + db_session, + run_metadata["project"], + schedule_name, + last_run_uri=run_uri, ) close_session(db_session) diff --git a/mlrun/api/utils/singletons/scheduler.py b/mlrun/api/utils/singletons/scheduler.py index b15dd60879..61583f8a9e 100644 --- a/mlrun/api/utils/singletons/scheduler.py +++ b/mlrun/api/utils/singletons/scheduler.py @@ -11,7 +11,9 @@ async def initialize_scheduler(): db_session = None try: db_session = create_session() - await scheduler.start(db_session,) + await scheduler.start( + db_session, + ) finally: db_session.close() diff --git a/mlrun/artifacts/base.py b/mlrun/artifacts/base.py index 06c41b1797..342f1c8e13 100644 --- a/mlrun/artifacts/base.py +++ b/mlrun/artifacts/base.py @@ -232,7 +232,10 @@ def blob_hash(data): def upload_extra_data( - artifact_spec: Artifact, extra_data: dict, prefix="", update_spec=False, + artifact_spec: Artifact, + extra_data: dict, + prefix="", + update_spec=False, ): if not extra_data: return diff --git a/mlrun/artifacts/manager.py b/mlrun/artifacts/manager.py index 4014383317..9ad65966a9 100644 --- a/mlrun/artifacts/manager.py +++ b/mlrun/artifacts/manager.py @@ -57,7 +57,9 @@ def dict_to_artifact(struct: dict): class ArtifactManager: def __init__( - self, db: RunDBInterface = None, calc_hash=True, + self, + db: RunDBInterface = None, + calc_hash=True, ): self.calc_hash = calc_hash diff --git a/mlrun/artifacts/plots.py b/mlrun/artifacts/plots.py index ed22e79164..b3d97c1cee 100644 --- a/mlrun/artifacts/plots.py +++ b/mlrun/artifacts/plots.py @@ -135,7 +135,10 @@ class BokehArtifact(Artifact): kind = "bokeh" def __init__( - self, figure, key: str = None, target_path: str = None, + self, + figure, + key: str = None, + target_path: str = None, ): """ Initialize a Bokeh artifact with the given figure. @@ -183,7 +186,10 @@ class PlotlyArtifact(Artifact): kind = "plotly" def __init__( - self, figure, key: str = None, target_path: str = None, + self, + figure, + key: str = None, + target_path: str = None, ): """ Initialize a Plotly artifact with the given figure. 
diff --git a/mlrun/builder.py b/mlrun/builder.py index 8273a36590..20ca0a91c2 100644 --- a/mlrun/builder.py +++ b/mlrun/builder.py @@ -351,9 +351,9 @@ def build_runtime( if build.base_image: runtime.spec.image = build.base_image elif runtime.kind in mlrun.mlconf.function_defaults.image_by_kind.to_dict(): - runtime.spec.image = mlrun.mlconf.function_defaults.image_by_kind.to_dict()[ - runtime.kind - ] + runtime.spec.image = ( + mlrun.mlconf.function_defaults.image_by_kind.to_dict()[runtime.kind] + ) if not runtime.spec.image: raise mlrun.errors.MLRunInvalidArgumentError( "The deployment was not successful because no image was specified or there are missing build parameters" diff --git a/mlrun/datastore/base.py b/mlrun/datastore/base.py index c6a95dddfd..db7c2d9fe8 100644 --- a/mlrun/datastore/base.py +++ b/mlrun/datastore/base.py @@ -369,7 +369,11 @@ def local(self): return self._local_path def as_df( - self, columns=None, df_module=None, format="", **kwargs, + self, + columns=None, + df_module=None, + format="", + **kwargs, ): """return a dataframe object (generated from the dataitem). diff --git a/mlrun/datastore/sources.py b/mlrun/datastore/sources.py index 4ccedd874d..3f690222da 100644 --- a/mlrun/datastore/sources.py +++ b/mlrun/datastore/sources.py @@ -98,22 +98,22 @@ def is_iterator(self): class CSVSource(BaseSourceDriver): """ - Reads CSV file as input source for a flow. - - :parameter name: name of the source - :parameter path: path to CSV file - :parameter key_field: the CSV field to be used as the key for events. May be an int (field index) or string - (field name) if with_header is True. Defaults to None (no key). Can be a list of keys. - :parameter time_field: the CSV field to be parsed as the timestamp for events. May be an int (field index) or - string (field name) if with_header is True. Defaults to None (no timestamp field). The field will be parsed - from isoformat (ISO-8601 as defined in datetime.fromisoformat()). In case the format is not isoformat, - timestamp_format (as defined in datetime.strptime()) should be passed in attributes. - :parameter schedule: string to configure scheduling of the ingestion job. - :parameter attributes: additional parameters to pass to storey. For example: - attributes={"timestamp_format": '%Y%m%d%H'} - :parameter parse_dates: Optional. List of columns (names or integers, other than time_field) that will be - attempted to parse as date column. - """ + Reads CSV file as input source for a flow. + + :parameter name: name of the source + :parameter path: path to CSV file + :parameter key_field: the CSV field to be used as the key for events. May be an int (field index) or string + (field name) if with_header is True. Defaults to None (no key). Can be a list of keys. + :parameter time_field: the CSV field to be parsed as the timestamp for events. May be an int (field index) or + string (field name) if with_header is True. Defaults to None (no timestamp field). The field will be parsed + from isoformat (ISO-8601 as defined in datetime.fromisoformat()). In case the format is not isoformat, + timestamp_format (as defined in datetime.strptime()) should be passed in attributes. + :parameter schedule: string to configure scheduling of the ingestion job. + :parameter attributes: additional parameters to pass to storey. For example: + attributes={"timestamp_format": '%Y%m%d%H'} + :parameter parse_dates: Optional. List of columns (names or integers, other than time_field) that will be + attempted to parse as date column. 
+ """ kind = "csv" support_storey = True @@ -172,21 +172,21 @@ def is_iterator(self): class ParquetSource(BaseSourceDriver): """ - Reads Parquet file/dir as input source for a flow. - - :parameter name: name of the source - :parameter path: path to Parquet file or directory - :parameter key_field: the column to be used as the key for events. Can be a list of keys. - :parameter time_field: the column to be parsed as the timestamp for events. Defaults to None - :parameter start_filter: datetime. If not None, the results will be filtered by partitions and - 'filter_column' > start_filter. Default is None - :parameter end_filter: datetime. If not None, the results will be filtered by partitions - 'filter_column' <= end_filter. Default is None - :parameter filter_column: Optional. if not None, the results will be filtered by this column and - start_filter & end_filter - :parameter schedule: string to configure scheduling of the ingestion job. For example '*/30 * * * *' will - cause the job to run every 30 minutes - :parameter attributes: additional parameters to pass to storey. + Reads Parquet file/dir as input source for a flow. + + :parameter name: name of the source + :parameter path: path to Parquet file or directory + :parameter key_field: the column to be used as the key for events. Can be a list of keys. + :parameter time_field: the column to be parsed as the timestamp for events. Defaults to None + :parameter start_filter: datetime. If not None, the results will be filtered by partitions and + 'filter_column' > start_filter. Default is None + :parameter end_filter: datetime. If not None, the results will be filtered by partitions + 'filter_column' <= end_filter. Default is None + :parameter filter_column: Optional. if not None, the results will be filtered by this column and + start_filter & end_filter + :parameter schedule: string to configure scheduling of the ingestion job. For example '*/30 * * * *' will + cause the job to run every 30 minutes + :parameter attributes: additional parameters to pass to storey. """ kind = "parquet" @@ -261,33 +261,33 @@ def to_dataframe(self): class BigQuerySource(BaseSourceDriver): """ - Reads Google BigQuery query results as input source for a flow. - - example:: - - # use sql query - query_string = "SELECT * FROM `the-psf.pypi.downloads20210328` LIMIT 5000" - source = BigQuerySource("bq1", query=query_string, - gcp_project="my_project", - materialization_dataset="dataviews") - - # read a table - source = BigQuerySource("bq2", table="the-psf.pypi.downloads20210328", gcp_project="my_project") - - - :parameter name: source name - :parameter table: table name/path, cannot be used together with query - :parameter query: sql query string - :parameter materialization_dataset: for query with spark, The target dataset for the materialized view. - This dataset should be in same location as the view or the queried tables. - must be set to a dataset where the GCP user has table creation permission - :parameter chunksize: number of rows per chunk (default large single chunk) - :parameter key_field: the column to be used as the key for events. Can be a list of keys. - :parameter time_field: the column to be parsed as the timestamp for events. Defaults to None - :parameter schedule: string to configure scheduling of the ingestion job. 
For example '*/30 * * * *' will - cause the job to run every 30 minutes - :parameter gcp_project: google cloud project name - :parameter spark_options: additional spart read options + Reads Google BigQuery query results as input source for a flow. + + example:: + + # use sql query + query_string = "SELECT * FROM `the-psf.pypi.downloads20210328` LIMIT 5000" + source = BigQuerySource("bq1", query=query_string, + gcp_project="my_project", + materialization_dataset="dataviews") + + # read a table + source = BigQuerySource("bq2", table="the-psf.pypi.downloads20210328", gcp_project="my_project") + + + :parameter name: source name + :parameter table: table name/path, cannot be used together with query + :parameter query: sql query string + :parameter materialization_dataset: for query with spark, The target dataset for the materialized view. + This dataset should be in same location as the view or the queried tables. + must be set to a dataset where the GCP user has table creation permission + :parameter chunksize: number of rows per chunk (default large single chunk) + :parameter key_field: the column to be used as the key for events. Can be a list of keys. + :parameter time_field: the column to be parsed as the timestamp for events. Defaults to None + :parameter schedule: string to configure scheduling of the ingestion job. For example '*/30 * * * *' will + cause the job to run every 30 minutes + :parameter gcp_project: google cloud project name + :parameter spark_options: additional spart read options """ kind = "bigquery" @@ -454,11 +454,11 @@ def to_step(self, key_field=None, time_field=None, context=None): class DataFrameSource: """ - Reads data frame as input source for a flow. + Reads data frame as input source for a flow. - :parameter key_field: the column to be used as the key for events. Can be a list of keys. Defaults to None - :parameter time_field: the column to be parsed as the timestamp for events. Defaults to None - :parameter context: MLRun context. Defaults to None + :parameter key_field: the column to be used as the key for events. Can be a list of keys. Defaults to None + :parameter time_field: the column to be parsed as the timestamp for events. Defaults to None + :parameter context: MLRun context. Defaults to None """ support_storey = True @@ -572,14 +572,14 @@ def __init__( **kwargs, ): """ - Sets stream source for the flow. If stream doesn't exist it will create it - - :param name: stream name. Default "stream" - :param group: consumer group. Default "serving" - :param seek_to: from where to consume the stream. Default earliest - :param shards: number of shards in the stream. Default 1 - :param retention_in_hours: if stream doesn't exist and it will be created set retention time. Default 24h - :param extra_attributes: additional nuclio trigger attributes (key/value dict) + Sets stream source for the flow. If stream doesn't exist it will create it + + :param name: stream name. Default "stream" + :param group: consumer group. Default "serving" + :param seek_to: from where to consume the stream. Default earliest + :param shards: number of shards in the stream. Default 1 + :param retention_in_hours: if stream doesn't exist and it will be created set retention time. Default 24h + :param extra_attributes: additional nuclio trigger attributes (key/value dict) """ attrs = { "group": group, @@ -631,13 +631,13 @@ def __init__( ): """Sets kafka source for the flow - :param brokers: list of broker IP addresses - :param topics: list of topic names on which to listen. 
- :param group: consumer group. Default "serving" - :param initial_offset: from where to consume the stream. Default earliest - :param partitions: Optional, A list of partitions numbers for which the function receives events. - :param sasl_user: Optional, user name to use for sasl authentications - :param sasl_pass: Optional, password to use for sasl authentications + :param brokers: list of broker IP addresses + :param topics: list of topic names on which to listen. + :param group: consumer group. Default "serving" + :param initial_offset: from where to consume the stream. Default earliest + :param partitions: Optional, A list of partitions numbers for which the function receives events. + :param sasl_user: Optional, user name to use for sasl authentications + :param sasl_pass: Optional, password to use for sasl authentications """ if isinstance(topics, str): topics = [topics] diff --git a/mlrun/datastore/store_resources.py b/mlrun/datastore/store_resources.py index a6b2519a51..cd67e6405d 100644 --- a/mlrun/datastore/store_resources.py +++ b/mlrun/datastore/store_resources.py @@ -154,7 +154,10 @@ def get_store_resource(uri, db=None, secrets=None, project=None): if resource.get("kind", "") == "link": # todo: support other link types (not just iter, move this to the db/api layer resource = db.read_artifact( - key, tag=tag, iter=resource.get("link_iteration", 0), project=project, + key, + tag=tag, + iter=resource.get("link_iteration", 0), + project=project, ) if resource: # import here to avoid circular imports diff --git a/mlrun/datastore/targets.py b/mlrun/datastore/targets.py index f2039c9d09..e4091a06e0 100644 --- a/mlrun/datastore/targets.py +++ b/mlrun/datastore/targets.py @@ -345,7 +345,12 @@ def _get_column_list(self, features, timestamp_key, key_columns, with_type=False return result def write_dataframe( - self, df, key_column=None, timestamp_key=None, chunk_id=0, **kwargs, + self, + df, + key_column=None, + timestamp_key=None, + chunk_id=0, + **kwargs, ) -> typing.Optional[int]: if hasattr(df, "rdd"): options = self.get_spark_options(key_column, timestamp_key) @@ -1080,7 +1085,9 @@ def write_dataframe( container, path = split_path(path_with_container) frames_client = get_frames_client( - token=access_key, address=config.v3io_framesd, container=container, + token=access_key, + address=config.v3io_framesd, + container=container, ) frames_client.write( diff --git a/mlrun/datastore/v3io.py b/mlrun/datastore/v3io.py index c19dea4118..d0c4f63520 100644 --- a/mlrun/datastore/v3io.py +++ b/mlrun/datastore/v3io.py @@ -142,8 +142,8 @@ def listdir(self, key): return [obj.key[subpath_length:] for obj in response.output.contents] def rm(self, path, recursive=False, maxdepth=None): - """ Recursive rm file/folder - Workaround for v3io-fs not supporting recursive directory removal """ + """Recursive rm file/folder + Workaround for v3io-fs not supporting recursive directory removal""" fs = self.get_filesystem() if isinstance(path, str): diff --git a/mlrun/db/base.py b/mlrun/db/base.py index de2d07f49b..9a8861acfc 100644 --- a/mlrun/db/base.py +++ b/mlrun/db/base.py @@ -150,7 +150,11 @@ def delete_project( pass @abstractmethod - def store_project(self, name: str, project: schemas.Project,) -> schemas.Project: + def store_project( + self, + name: str, + project: schemas.Project, + ) -> schemas.Project: pass @abstractmethod @@ -163,7 +167,10 @@ def patch_project( pass @abstractmethod - def create_project(self, project: schemas.Project,) -> schemas.Project: + def create_project( + self, + project: 
schemas.Project, + ) -> schemas.Project: pass @abstractmethod @@ -209,7 +216,11 @@ def list_features( @abstractmethod def list_entities( - self, project: str, name: str = None, tag: str = None, labels: List[str] = None, + self, + project: str, + name: str = None, + tag: str = None, + labels: List[str] = None, ) -> schemas.EntitiesOutput: pass diff --git a/mlrun/db/filedb.py b/mlrun/db/filedb.py index 842d0a9222..44e206e848 100644 --- a/mlrun/db/filedb.py +++ b/mlrun/db/filedb.py @@ -80,7 +80,7 @@ def get_log(self, uid, project="", offset=0, size=0): if offset: fp.seek(offset) if not size: - size = 2 ** 18 + size = 2**18 return "", fp.read(size) return "", None @@ -153,7 +153,10 @@ def list_runs( and match_value_options(state, run, "status.state") and match_value(uid, run, "metadata.uid") and match_times( - start_time_from, start_time_to, run, "status.start_time", + start_time_from, + start_time_to, + run, + "status.start_time", ) and match_times( last_update_time_from, @@ -375,7 +378,9 @@ def list_functions(self, name=None, project="", tag="", labels=None): labels = labels or [] logger.info(f"reading functions in {project} name/mask: {name} tag: {tag} ...") filepath = path.join( - self.dirpath, functions_dir, project or config.default_project, + self.dirpath, + functions_dir, + project or config.default_project, ) filepath += "/" @@ -480,7 +485,9 @@ def delete_project( raise NotImplementedError() def store_project( - self, name: str, project: mlrun.api.schemas.Project, + self, + name: str, + project: mlrun.api.schemas.Project, ) -> mlrun.api.schemas.Project: raise NotImplementedError() @@ -493,7 +500,8 @@ def patch_project( raise NotImplementedError() def create_project( - self, project: mlrun.api.schemas.Project, + self, + project: mlrun.api.schemas.Project, ) -> mlrun.api.schemas.Project: raise NotImplementedError() @@ -571,7 +579,11 @@ def list_features( raise NotImplementedError() def list_entities( - self, project: str, name: str = None, tag: str = None, labels: List[str] = None, + self, + project: str, + name: str = None, + tag: str = None, + labels: List[str] = None, ): raise NotImplementedError() @@ -597,7 +609,13 @@ def store_feature_set( raise NotImplementedError() def patch_feature_set( - self, name, feature_set, project="", tag=None, uid=None, patch_mode="replace", + self, + name, + feature_set, + project="", + tag=None, + uid=None, + patch_mode="replace", ): raise NotImplementedError() @@ -627,7 +645,13 @@ def list_feature_vectors( raise NotImplementedError() def store_feature_vector( - self, feature_vector, name=None, project="", tag=None, uid=None, versioned=True, + self, + feature_vector, + name=None, + project="", + tag=None, + uid=None, + versioned=True, ): raise NotImplementedError() diff --git a/mlrun/db/httpdb.py b/mlrun/db/httpdb.py index 78ae5e87f8..3d12633e08 100644 --- a/mlrun/db/httpdb.py +++ b/mlrun/db/httpdb.py @@ -68,7 +68,7 @@ def bool2str(val): class HTTPRunDB(RunDBInterface): - """ Interface for accessing and manipulating the :py:mod:`mlrun` persistent store, maintaining the full state + """Interface for accessing and manipulating the :py:mod:`mlrun` persistent store, maintaining the full state and catalog of objects that MLRun uses. The :py:class:`HTTPRunDB` class serves as a client-side proxy to the MLRun API service which maintains the actual data-store, accesses the server through REST APIs. 
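The httpdb.py hunks that follow mostly reflow docstrings (dropping the space after the opening quotes); the class docstring above already describes `HTTPRunDB` as a client-side proxy over the MLRun REST API. A minimal sketch of the usual flow, assuming the constructor takes the API address as its first argument and that a server is reachable at the illustrative URL below::

    from mlrun.db.httpdb import HTTPRunDB

    # the address is illustrative; in practice it usually comes from
    # mlconf.dbpath, as the connect() docstring below notes
    db = HTTPRunDB("http://mlrun-api:8080")
    db.connect()  # must be called before any other method

    # filters mirror the list_runs docstring example further down this diff
    runs = db.list_runs(project="iris", labels="owner=admin")
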
@@ -123,8 +123,8 @@ def __repr__(self): @staticmethod def get_api_path_prefix(version: str = None) -> str: """ - :param version: API version to use, None (the default) will mean to use the default value from mlconf, - for un-versioned api set an empty string. + :param version: API version to use, None (the default) will mean to use the default value from mlconf, + for un-versioned api set an empty string. """ if version is not None: return f"api/{version}" if version else "api" @@ -151,23 +151,23 @@ def api_call( timeout=45, version=None, ): - """ Perform a direct REST API call on the :py:mod:`mlrun` API server. + """Perform a direct REST API call on the :py:mod:`mlrun` API server. - Caution: - For advanced usage - prefer using the various APIs exposed through this class, rather than - directly invoking REST calls. + Caution: + For advanced usage - prefer using the various APIs exposed through this class, rather than + directly invoking REST calls. - :param method: REST method (POST, GET, PUT...) - :param path: Path to endpoint executed, for example ``"projects"`` - :param error: Error to return if API invocation fails - :param body: Payload to be passed in the call. If using JSON objects, prefer using the ``json`` param - :param json: JSON payload to be passed in the call - :param headers: REST headers, passed as a dictionary: ``{"": ""}`` - :param timeout: API call timeout - :param version: API version to use, None (the default) will mean to use the default value from config, - for un-versioned api set an empty string. + :param method: REST method (POST, GET, PUT...) + :param path: Path to endpoint executed, for example ``"projects"`` + :param error: Error to return if API invocation fails + :param body: Payload to be passed in the call. If using JSON objects, prefer using the ``json`` param + :param json: JSON payload to be passed in the call + :param headers: REST headers, passed as a dictionary: ``{"": ""}`` + :param timeout: API call timeout + :param version: API version to use, None (the default) will mean to use the default value from config, + for un-versioned api set an empty string. - :return: Python HTTP response object + :return: Python HTTP response object """ url = self.get_base_api_url(path, version) kw = { @@ -238,7 +238,7 @@ def _path_of(self, prefix, project, uid): return f"{prefix}/{project}/{uid}" def connect(self, secrets=None): - """ Connect to the MLRun API server. Must be called prior to executing any other method. + """Connect to the MLRun API server. Must be called prior to executing any other method. The code utilizes the URL for the API server from the configuration - ``mlconf.dbpath``. For example:: @@ -343,7 +343,7 @@ def connect(self, secrets=None): return self def store_log(self, uid, project="", body=None, append=False): - """ Save a log persistently. + """Save a log persistently. :param uid: Log unique ID :param project: Project name for which this log belongs @@ -361,7 +361,7 @@ def store_log(self, uid, project="", body=None, append=False): self.api_call("POST", path, error, params, body) def get_log(self, uid, project="", offset=0, size=-1): - """ Retrieve a log. + """Retrieve a log. :param uid: Log unique ID :param project: Project name for which the log belongs @@ -386,7 +386,7 @@ def get_log(self, uid, project="", offset=0, size=-1): return "unknown", resp.content def watch_log(self, uid, project="", watch=True, offset=0): - """ Retrieve logs of a running process, and watch the progress of the execution until it completes. 
This + """Retrieve logs of a running process, and watch the progress of the execution until it completes. This method will print out the logs and continue to periodically poll for, and print, new logs as long as the state of the runtime which generates this log is either ``pending`` or ``running``. @@ -419,7 +419,7 @@ def watch_log(self, uid, project="", watch=True, offset=0): return state def store_run(self, struct, uid, project="", iter=0): - """ Store run details in the DB. This method is usually called from within other :py:mod:`mlrun` flows + """Store run details in the DB. This method is usually called from within other :py:mod:`mlrun` flows and not called directly by the user.""" path = self._path_of("run", project, uid) @@ -429,7 +429,7 @@ def store_run(self, struct, uid, project="", iter=0): self.api_call("POST", path, error, params=params, body=body) def update_run(self, updates: dict, uid, project="", iter=0): - """ Update the details of a stored run in the DB.""" + """Update the details of a stored run in the DB.""" path = self._path_of("run", project, uid) params = {"iter": iter} @@ -449,7 +449,7 @@ def abort_run(self, uid, project="", iter=0): ) def read_run(self, uid, project="", iter=0): - """ Read the details of a stored run from the DB. + """Read the details of a stored run from the DB. :param uid: The run's unique ID. :param project: Project name. @@ -463,7 +463,7 @@ def read_run(self, uid, project="", iter=0): return resp.json()["data"] def del_run(self, uid, project="", iter=0): - """ Delete details of a specific run from DB. + """Delete details of a specific run from DB. :param uid: Unique ID for the specific run to delete. :param project: Project that the run belongs to. @@ -494,7 +494,7 @@ def list_runs( partition_sort_by: Union[schemas.SortField, str] = None, partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, ) -> RunList: - """ Retrieve a list of runs, filtered by various options. + """Retrieve a list of runs, filtered by various options. Example:: runs = db.list_runs(name='download', project='iris', labels='owner=admin') @@ -556,7 +556,7 @@ def list_runs( return RunList(resp.json()["runs"]) def del_runs(self, name=None, project=None, labels=None, state=None, days_ago=0): - """ Delete a group of runs identified by the parameters of the function. + """Delete a group of runs identified by the parameters of the function. Example:: @@ -581,7 +581,7 @@ def del_runs(self, name=None, project=None, labels=None, state=None, days_ago=0) self.api_call("DELETE", "runs", error, params=params) def store_artifact(self, key, artifact, uid, iter=None, tag=None, project=""): - """ Store an artifact in the DB. + """Store an artifact in the DB. :param key: Identifying key of the artifact. :param artifact: The actual artifact to store. @@ -605,7 +605,7 @@ def store_artifact(self, key, artifact, uid, iter=None, tag=None, project=""): self.api_call("POST", path, error, params=params, body=body) def read_artifact(self, key, tag=None, iter=None, project=""): - """ Read an artifact, identified by its key, tag and iteration.""" + """Read an artifact, identified by its key, tag and iteration.""" project = project or config.default_project tag = tag or "latest" @@ -616,7 +616,7 @@ def read_artifact(self, key, tag=None, iter=None, project=""): return resp.json()["data"] def del_artifact(self, key, tag=None, project=""): - """ Delete an artifact.""" + """Delete an artifact.""" path = self._path_of("artifact", project, key) # TODO: uid? 
params = { @@ -639,7 +639,7 @@ def list_artifacts( kind: str = None, category: Union[str, schemas.ArtifactCategories] = None, ) -> ArtifactList: - """ List artifacts filtered by various parameters. + """List artifacts filtered by various parameters. Examples:: @@ -685,7 +685,7 @@ def list_artifacts( return values def del_artifacts(self, name=None, project=None, tag=None, labels=None, days_ago=0): - """ Delete artifacts referenced by the parameters. + """Delete artifacts referenced by the parameters. :param name: Name of artifacts to delete. Note that this is a like query, and is case-insensitive. See :py:func:`~list_artifacts` for more details. @@ -706,7 +706,7 @@ def del_artifacts(self, name=None, project=None, tag=None, labels=None, days_ago self.api_call("DELETE", "artifacts", error, params=params) def list_artifact_tags(self, project=None) -> List[str]: - """ Return a list of all the tags assigned to artifacts in the scope of the given project.""" + """Return a list of all the tags assigned to artifacts in the scope of the given project.""" project = project or config.default_project error_message = f"Failed listing artifact tags. project={project}" @@ -716,7 +716,7 @@ def list_artifact_tags(self, project=None) -> List[str]: return response.json()["tags"] def store_function(self, function, name, project="", tag=None, versioned=False): - """ Store a function object. Function is identified by its name and tag, and can be versioned.""" + """Store a function object. Function is identified by its name and tag, and can be versioned.""" params = {"tag": tag, "versioned": versioned} project = project or config.default_project @@ -731,7 +731,7 @@ def store_function(self, function, name, project="", tag=None, versioned=False): return resp.json().get("hash_key") def get_function(self, name, project="", tag=None, hash_key=""): - """ Retrieve details of a specific function, identified by its name and potentially a tag or function hash.""" + """Retrieve details of a specific function, identified by its name and potentially a tag or function hash.""" params = {"tag": tag, "hash_key": hash_key} project = project or config.default_project @@ -741,7 +741,7 @@ def get_function(self, name, project="", tag=None, hash_key=""): return resp.json()["func"] def delete_function(self, name: str, project: str = ""): - """ Delete a function belonging to a specific project.""" + """Delete a function belonging to a specific project.""" project = project or config.default_project path = f"projects/{project}/functions/{name}" @@ -749,7 +749,7 @@ def delete_function(self, name: str, project: str = ""): self.api_call("DELETE", path, error_message) def list_functions(self, name=None, project=None, tag=None, labels=None): - """ Retrieve a list of functions, filtered by specific criteria. + """Retrieve a list of functions, filtered by specific criteria. :param name: Return only functions with a specific name. :param project: Return functions belonging to this project. If not specified, the default project is used. @@ -780,7 +780,7 @@ def list_runtime_resources( mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, ]: - """ List current runtime resources, which are usually (but not limited to) Kubernetes pods or CRDs. + """List current runtime resources, which are usually (but not limited to) Kubernetes pods or CRDs. 
Function applies for runs of type ``['dask', 'job', 'spark', 'remote-spark', 'mpijob']``, and will return per runtime kind a list of the runtime resources (which may have already completed their execution). @@ -835,8 +835,7 @@ def list_runtime_resources( ) def list_runtimes(self, label_selector: str = None) -> List: - """ Deprecated use :py:func:`~list_runtime_resources` instead - """ + """Deprecated use :py:func:`~list_runtime_resources` instead""" warnings.warn( "This method is deprecated, use list_runtime_resources instead" "This will be removed in 0.9.0", @@ -849,8 +848,7 @@ def list_runtimes(self, label_selector: str = None) -> List: return resp.json() def get_runtime(self, kind: str, label_selector: str = None) -> Dict: - """ Deprecated use :py:func:`~list_runtime_resources` (with kind filter) instead - """ + """Deprecated use :py:func:`~list_runtime_resources` (with kind filter) instead""" warnings.warn( "This method is deprecated, use list_runtime_resources (with kind filter) instead" "This will be removed in 0.9.0", @@ -872,7 +870,7 @@ def delete_runtime_resources( force: bool = False, grace_period: int = None, ) -> mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput: - """ Delete all runtime resources which are in terminal state. + """Delete all runtime resources which are in terminal state. :param project: Delete only runtime resources of a specific project, by default None, which will delete only from the projects you're authorized to delete from. @@ -916,10 +914,12 @@ def delete_runtime_resources( return structured_dict def delete_runtimes( - self, label_selector: str = None, force: bool = False, grace_period: int = None, + self, + label_selector: str = None, + force: bool = False, + grace_period: int = None, ): - """ Deprecated use :py:func:`~delete_runtime_resources` instead - """ + """Deprecated use :py:func:`~delete_runtime_resources` instead""" warnings.warn( "This method is deprecated, use delete_runtime_resources instead" "This will be removed in 0.9.0", @@ -943,8 +943,7 @@ def delete_runtime( force: bool = False, grace_period: int = None, ): - """ Deprecated use :py:func:`~delete_runtime_resources` (with kind filter) instead - """ + """Deprecated use :py:func:`~delete_runtime_resources` (with kind filter) instead""" warnings.warn( "This method is deprecated, use delete_runtime_resources (with kind filter) instead" "This will be removed in 0.9.0", @@ -972,8 +971,7 @@ def delete_runtime_object( force: bool = False, grace_period: int = None, ): - """ Deprecated use :py:func:`~delete_runtime_resources` (with kind and object_id filter) instead - """ + """Deprecated use :py:func:`~delete_runtime_resources` (with kind and object_id filter) instead""" warnings.warn( "This method is deprecated, use delete_runtime_resources (with kind and object_id filter) instead" "This will be removed in 0.9.0", @@ -993,7 +991,7 @@ def delete_runtime_object( self.api_call("DELETE", path, error, params=params) def create_schedule(self, project: str, schedule: schemas.ScheduleInput): - """ Create a new schedule on the given project. The details on the actual object to schedule as well as the + """Create a new schedule on the given project. The details on the actual object to schedule as well as the schedule itself are within the schedule object provided. The :py:class:`~ScheduleCronTrigger` follows the guidelines in https://apscheduler.readthedocs.io/en/v3.6.3/modules/triggers/cron.html. 
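The ``create_schedule`` docstring above points to apscheduler-style cron triggers. The sketch below is only an illustration of the call's shape: the ``ScheduleInput`` / ``ScheduleCronTrigger`` field names (``name``, ``kind``, ``scheduled_object``, ``cron_trigger``, ``minute``, ``hour``) and the payload passed as ``scheduled_object`` are assumptions not confirmed by this diff::

    import mlrun
    from mlrun.api import schemas

    db = mlrun.get_run_db()

    # Hypothetical cron trigger: minute 0 of every fourth hour (field names
    # assumed to mirror apscheduler's CronTrigger, per the docstring above).
    trigger = schemas.ScheduleCronTrigger(minute="0", hour="*/4")

    schedule = schemas.ScheduleInput(
        name="nightly-ingest",  # hypothetical schedule name
        kind=schemas.ScheduleKinds.job,
        scheduled_object={"task": {"metadata": {"name": "ingest"}}},  # placeholder payload
        cron_trigger=trigger,
    )
    db.create_schedule(project="iris", schedule=schedule)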
@@ -1023,7 +1021,7 @@ def create_schedule(self, project: str, schedule: schemas.ScheduleInput): def update_schedule( self, project: str, name: str, schedule: schemas.ScheduleUpdate ): - """ Update an existing schedule, replace it with the details contained in the schedule object.""" + """Update an existing schedule, replace it with the details contained in the schedule object.""" project = project or config.default_project path = f"projects/{project}/schedules/{name}" @@ -1034,7 +1032,7 @@ def update_schedule( def get_schedule( self, project: str, name: str, include_last_run: bool = False ) -> schemas.ScheduleOutput: - """ Retrieve details of the schedule in question. Besides returning the details of the schedule object itself, + """Retrieve details of the schedule in question. Besides returning the details of the schedule object itself, this function also returns the next scheduled run for this specific schedule, as well as potentially the results of the last run executed through this schedule. @@ -1058,7 +1056,7 @@ def list_schedules( kind: schemas.ScheduleKinds = None, include_last_run: bool = False, ) -> schemas.SchedulesOutput: - """ Retrieve list of schedules of specific name or kind. + """Retrieve list of schedules of specific name or kind. :param project: Project name. :param name: Name of schedule to retrieve. Can be omitted to list all schedules. @@ -1075,7 +1073,7 @@ def list_schedules( return schemas.SchedulesOutput(**resp.json()) def delete_schedule(self, project: str, name: str): - """ Delete a specific schedule by name. """ + """Delete a specific schedule by name.""" project = project or config.default_project path = f"projects/{project}/schedules/{name}" @@ -1083,7 +1081,7 @@ def delete_schedule(self, project: str, name: str): self.api_call("DELETE", path, error_message) def invoke_schedule(self, project: str, name: str): - """ Execute the object referenced by the schedule immediately. """ + """Execute the object referenced by the schedule immediately.""" project = project or config.default_project path = f"projects/{project}/schedules/{name}/invoke" @@ -1098,7 +1096,7 @@ def remote_builder( skip_deployed=False, builder_env=None, ): - """ Build the pod image for a function, for execution on a remote cluster. This is executed by the MLRun + """Build the pod image for a function, for execution on a remote cluster. This is executed by the MLRun API server, and creates a Docker image out of the function provided and any specific build instructions provided within. This is a pre-requisite for remotely executing a function, unless using a pre-deployed image. @@ -1140,7 +1138,7 @@ def get_builder_status( last_log_timestamp=0, verbose=False, ): - """ Retrieve the status of a build operation currently in progress. + """Retrieve the status of a build operation currently in progress. :param func: Function object that is being built. :param offset: Offset into the build logs to retrieve logs from. @@ -1199,7 +1197,7 @@ def get_builder_status( return text, last_log_timestamp def remote_start(self, func_url) -> schemas.BackgroundTask: - """ Execute a function remotely, Used for ``dask`` functions. + """Execute a function remotely, Used for ``dask`` functions. :param func_url: URL to the function to be executed. :returns: A BackgroundTask object, with details on execution process and its status. 
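``remote_start`` above returns a ``schemas.BackgroundTask``, and the next hunk shows ``get_project_background_task`` for polling it. A sketch of that flow; the ``db://`` function URL format and the ``metadata.name`` / ``status.state`` layout of the background task object are assumptions, not confirmed by this diff::

    import mlrun

    db = mlrun.get_run_db()

    # Start a dask function remotely; the function URL below is illustrative.
    task = db.remote_start("db://iris/my-dask-cluster")

    # Poll the background task that tracks the remote start operation.
    task = db.get_project_background_task("iris", task.metadata.name)
    print(task.status.state)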
@@ -1224,9 +1222,11 @@ def remote_start(self, func_url) -> schemas.BackgroundTask: return schemas.BackgroundTask(**resp.json()) def get_project_background_task( - self, project: str, name: str, + self, + project: str, + name: str, ) -> schemas.BackgroundTask: - """ Retrieve updated information on a background task being executed.""" + """Retrieve updated information on a background task being executed.""" project = project or config.default_project path = f"projects/{project}/background-tasks/{name}" @@ -1237,7 +1237,7 @@ def get_project_background_task( return schemas.BackgroundTask(**response.json()) def remote_status(self, project, name, kind, selector): - """ Retrieve status of a function being executed remotely (relevant to ``dask`` functions). + """Retrieve status of a function being executed remotely (relevant to ``dask`` functions). :param project: The project of the function :param name: The name of the function @@ -1261,7 +1261,7 @@ def remote_status(self, project, name, kind, selector): def submit_job( self, runspec, schedule: Union[str, schemas.ScheduleCronTrigger] = None ): - """ Submit a job for remote execution. + """Submit a job for remote execution. :param runspec: The runtime object spec (Task) to execute. :param schedule: Whether to schedule this job using a Cron trigger. If not specified, the job will be submitted @@ -1299,7 +1299,7 @@ def submit_pipeline( ops=None, ttl=None, ): - """ Submit a KFP pipeline for execution. + """Submit a KFP pipeline for execution. :param project: The project of the pipeline :param pipeline: Pipeline function or path to .yaml/.zip pipeline file. @@ -1373,7 +1373,7 @@ def list_pipelines( ] = mlrun.api.schemas.PipelinesFormat.metadata_only, page_size: int = None, ) -> mlrun.api.schemas.PipelinesOutput: - """ Retrieve a list of KFP pipelines. This function can be invoked to get all pipelines from all projects, + """Retrieve a list of KFP pipelines. This function can be invoked to get all pipelines from all projects, by specifying ``project=*``, in which case pagination can be used and the various sorting and pagination properties can be applied. If a specific project is requested, then the pagination options cannot be used and pagination is not applied. @@ -1422,7 +1422,7 @@ def get_pipeline( ] = mlrun.api.schemas.PipelinesFormat.summary, project: str = None, ): - """ Retrieve details of a specific pipeline using its run ID (as provided when the pipeline was executed).""" + """Retrieve details of a specific pipeline using its run ID (as provided when the pipeline was executed).""" if isinstance(format_, mlrun.api.schemas.PipelinesFormat): format_ = format_.value @@ -1457,7 +1457,7 @@ def _resolve_reference(tag, uid): def create_feature_set( self, feature_set: Union[dict, schemas.FeatureSet], project="", versioned=True ) -> dict: - """ Create a new :py:class:`~mlrun.feature_store.FeatureSet` and save in the :py:mod:`mlrun` DB. The + """Create a new :py:class:`~mlrun.feature_store.FeatureSet` and save in the :py:mod:`mlrun` DB. The feature-set must not previously exist in the DB. :param feature_set: The new :py:class:`~mlrun.feature_store.FeatureSet` to create. 
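``create_feature_set`` above accepts either a dict or a ``schemas.FeatureSet``. A minimal sketch using the client-side ``FeatureSet`` object and converting it with ``to_dict()``; the ``stocks`` / ``ticker`` names are illustrative only::

    import mlrun
    from mlrun.feature_store import Entity, FeatureSet

    db = mlrun.get_run_db()

    # Define a feature set keyed by a "ticker" entity and store it in the DB.
    stocks_set = FeatureSet("stocks", entities=[Entity("ticker")])
    db.create_feature_set(stocks_set.to_dict(), project="iris", versioned=True)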
@@ -1480,14 +1480,18 @@ def create_feature_set( name = feature_set["metadata"]["name"] error_message = f"Failed creating feature-set {project}/{name}" resp = self.api_call( - "POST", path, error_message, params=params, body=dict_to_json(feature_set), + "POST", + path, + error_message, + params=params, + body=dict_to_json(feature_set), ) return resp.json() def get_feature_set( self, name: str, project: str = "", tag: str = None, uid: str = None ) -> FeatureSet: - """ Retrieve a ~mlrun.feature_store.FeatureSet` object. If both ``tag`` and ``uid`` are not specified, then + """Retrieve a ~mlrun.feature_store.FeatureSet` object. If both ``tag`` and ``uid`` are not specified, then the object tagged ``latest`` will be retrieved. :param name: Name of object to retrieve. @@ -1511,7 +1515,7 @@ def list_features( entities: List[str] = None, labels: List[str] = None, ) -> List[dict]: - """ List feature-sets which contain specific features. This function may return multiple versions of the same + """List feature-sets which contain specific features. This function may return multiple versions of the same feature-set if a specific tag is not requested. Note that the various filters of this function actually refer to the feature-set object containing the features, not to the features themselves. @@ -1541,9 +1545,13 @@ def list_features( return resp.json()["features"] def list_entities( - self, project: str, name: str = None, tag: str = None, labels: List[str] = None, + self, + project: str, + name: str = None, + tag: str = None, + labels: List[str] = None, ) -> List[dict]: - """ Retrieve a list of entities and their mapping to the containing feature-sets. This function is similar + """Retrieve a list of entities and their mapping to the containing feature-sets. This function is similar to the :py:func:`~list_features` function, and uses the same logic. However, the entities are matched against the name rather than the features. """ @@ -1593,7 +1601,7 @@ def list_feature_sets( partition_sort_by: Union[schemas.SortField, str] = None, partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, ) -> List[FeatureSet]: - """ Retrieve a list of feature-sets matching the criteria provided. + """Retrieve a list of feature-sets matching the criteria provided. :param project: Project name. :param name: Name of feature-set to match. This is a like query, and is case-insensitive. @@ -1652,7 +1660,7 @@ def store_feature_set( uid=None, versioned=True, ) -> dict: - """ Save a :py:class:`~mlrun.feature_store.FeatureSet` object in the :py:mod:`mlrun` DB. The + """Save a :py:class:`~mlrun.feature_store.FeatureSet` object in the :py:mod:`mlrun` DB. The feature-set can be either a new object or a modification to existing object referenced by the params of the function. @@ -1692,7 +1700,7 @@ def patch_feature_set( uid=None, patch_mode: Union[str, schemas.PatchMode] = schemas.PatchMode.replace, ): - """ Modify (patch) an existing :py:class:`~mlrun.feature_store.FeatureSet` object. + """Modify (patch) an existing :py:class:`~mlrun.feature_store.FeatureSet` object. The object is identified by its name (and project it belongs to), as well as optionally a ``tag`` or its ``uid`` (for versioned object). If both ``tag`` and ``uid`` are omitted then the object with tag ``latest`` is modified. @@ -1727,7 +1735,7 @@ def patch_feature_set( ) def delete_feature_set(self, name, project="", tag=None, uid=None): - """ Delete a :py:class:`~mlrun.feature_store.FeatureSet` object from the DB. 
+ """Delete a :py:class:`~mlrun.feature_store.FeatureSet` object from the DB. If ``tag`` or ``uid`` are specified, then just the version referenced by them will be deleted. Using both is not allowed. If none are specified, then all instances of the object whose name is ``name`` will be deleted. @@ -1748,7 +1756,7 @@ def create_feature_vector( project="", versioned=True, ) -> dict: - """ Create a new :py:class:`~mlrun.feature_store.FeatureVector` and save in the :py:mod:`mlrun` DB. + """Create a new :py:class:`~mlrun.feature_store.FeatureVector` and save in the :py:mod:`mlrun` DB. :param feature_vector: The new :py:class:`~mlrun.feature_store.FeatureVector` to create. :param project: Name of project this feature-vector belongs to. @@ -1781,8 +1789,8 @@ def create_feature_vector( def get_feature_vector( self, name: str, project: str = "", tag: str = None, uid: str = None ) -> FeatureVector: - """ Return a specific feature-vector referenced by its tag or uid. If none are provided, ``latest`` tag will - be used. """ + """Return a specific feature-vector referenced by its tag or uid. If none are provided, ``latest`` tag will + be used.""" project = project or config.default_project reference = self._resolve_reference(tag, uid) @@ -1803,7 +1811,7 @@ def list_feature_vectors( partition_sort_by: Union[schemas.SortField, str] = None, partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, ) -> List[FeatureVector]: - """ Retrieve a list of feature-vectors matching the criteria provided. + """Retrieve a list of feature-vectors matching the criteria provided. :param project: Project name. :param name: Name of feature-vector to match. This is a like query, and is case-insensitive. @@ -1858,7 +1866,7 @@ def store_feature_vector( uid=None, versioned=True, ) -> dict: - """ Store a :py:class:`~mlrun.feature_store.FeatureVector` object in the :py:mod:`mlrun` DB. The + """Store a :py:class:`~mlrun.feature_store.FeatureVector` object in the :py:mod:`mlrun` DB. The feature-vector can be either a new object or a modification to existing object referenced by the params of the function. @@ -1900,7 +1908,7 @@ def patch_feature_vector( uid=None, patch_mode: Union[str, schemas.PatchMode] = schemas.PatchMode.replace, ): - """ Modify (patch) an existing :py:class:`~mlrun.feature_store.FeatureVector` object. + """Modify (patch) an existing :py:class:`~mlrun.feature_store.FeatureVector` object. The object is identified by its name (and project it belongs to), as well as optionally a ``tag`` or its ``uid`` (for versioned object). If both ``tag`` and ``uid`` are omitted then the object with tag ``latest`` is modified. @@ -1930,7 +1938,7 @@ def patch_feature_vector( ) def delete_feature_vector(self, name, project="", tag=None, uid=None): - """ Delete a :py:class:`~mlrun.feature_store.FeatureVector` object from the DB. + """Delete a :py:class:`~mlrun.feature_store.FeatureVector` object from the DB. If ``tag`` or ``uid`` are specified, then just the version referenced by them will be deleted. Using both is not allowed. If none are specified, then all instances of the object whose name is ``name`` will be deleted. @@ -1953,7 +1961,7 @@ def list_projects( labels: List[str] = None, state: Union[str, mlrun.api.schemas.ProjectState] = None, ) -> List[Union[mlrun.projects.MlrunProject, str]]: - """ Return a list of the existing projects, potentially filtered by specific criteria. + """Return a list of the existing projects, potentially filtered by specific criteria. 
:param owner: List only projects belonging to this specific owner. :param format_: Format of the results. Possible values are: @@ -1991,7 +1999,7 @@ def list_projects( ) def get_project(self, name: str) -> mlrun.projects.MlrunProject: - """ Get details for a specific project.""" + """Get details for a specific project.""" if not name: raise MLRunInvalidArgumentError("Name must be provided") @@ -2008,7 +2016,7 @@ def delete_project( str, mlrun.api.schemas.DeletionStrategy ] = mlrun.api.schemas.DeletionStrategy.default(), ): - """ Delete a project. + """Delete a project. :param name: Name of the project to delete. :param deletion_strategy: How to treat child objects of the project. Possible values are: @@ -2032,7 +2040,7 @@ def store_project( name: str, project: Union[dict, mlrun.projects.MlrunProject, mlrun.api.schemas.Project], ) -> mlrun.projects.MlrunProject: - """ Store a project in the DB. This operation will overwrite existing project of the same name if exists.""" + """Store a project in the DB. This operation will overwrite existing project of the same name if exists.""" path = f"projects/{name}" error_message = f"Failed storing project {name}" @@ -2041,7 +2049,10 @@ def store_project( elif isinstance(project, mlrun.projects.MlrunProject): project = project.to_dict() response = self.api_call( - "PUT", path, error_message, body=dict_to_json(project), + "PUT", + path, + error_message, + body=dict_to_json(project), ) if response.status_code == http.HTTPStatus.ACCEPTED: return self._wait_for_project_to_reach_terminal_state(name) @@ -2053,7 +2064,7 @@ def patch_project( project: dict, patch_mode: Union[str, schemas.PatchMode] = schemas.PatchMode.replace, ) -> mlrun.projects.MlrunProject: - """ Patch an existing project object. + """Patch an existing project object. :param name: Name of project to patch. :param project: The actual changes to the project object. @@ -2075,7 +2086,7 @@ def create_project( self, project: Union[dict, mlrun.projects.MlrunProject, mlrun.api.schemas.Project], ) -> mlrun.projects.MlrunProject: - """ Create a new project. A project with the same name must not exist prior to creation.""" + """Create a new project. A project with the same name must not exist prior to creation.""" if isinstance(project, mlrun.api.schemas.Project): project = project.dict() @@ -2084,7 +2095,10 @@ def create_project( project_name = project["metadata"]["name"] error_message = f"Failed creating project {project_name}" response = self.api_call( - "POST", "projects", error_message, body=dict_to_json(project), + "POST", + "projects", + error_message, + body=dict_to_json(project), ) if response.status_code == http.HTTPStatus.ACCEPTED: return self._wait_for_project_to_reach_terminal_state(project_name) @@ -2136,7 +2150,7 @@ def create_project_secrets( ] = schemas.SecretProviderName.kubernetes, secrets: dict = None, ): - """ Create project-context secrets using either ``vault`` or ``kubernetes`` provider. + """Create project-context secrets using either ``vault`` or ``kubernetes`` provider. When using with Vault, this will create needed Vault structures for storing secrets in project-context, and store a set of secret values. 
The method generates Kubernetes service-account and the Vault authentication structures that are required for function Pods to authenticate with Vault and be able to extract secret values @@ -2169,7 +2183,10 @@ def create_project_secrets( body = secrets_input.dict() error_message = f"Failed creating secret provider {project}/{provider}" self.api_call( - "POST", path, error_message, body=dict_to_json(body), + "POST", + path, + error_message, + body=dict_to_json(body), ) def list_project_secrets( @@ -2181,7 +2198,7 @@ def list_project_secrets( ] = schemas.SecretProviderName.kubernetes, secrets: List[str] = None, ) -> schemas.SecretsData: - """ Retrieve project-context secrets from Vault. + """Retrieve project-context secrets from Vault. Note: This method for Vault functionality is currently in technical preview, and requires a HashiCorp Vault @@ -2208,7 +2225,11 @@ def list_project_secrets( headers = {schemas.HeaderNames.secret_store_token: token} error_message = f"Failed retrieving secrets {project}/{provider}" result = self.api_call( - "GET", path, error_message, params=params, headers=headers, + "GET", + path, + error_message, + params=params, + headers=headers, ) return schemas.SecretsData(**result.json()) @@ -2220,7 +2241,7 @@ def list_project_secret_keys( ] = schemas.SecretProviderName.kubernetes, token: str = None, ) -> schemas.SecretKeysData: - """ Retrieve project-context secret keys from Vault or Kubernetes. + """Retrieve project-context secret keys from Vault or Kubernetes. Note: This method for Vault functionality is currently in technical preview, and requires a HashiCorp Vault @@ -2250,7 +2271,11 @@ def list_project_secret_keys( ) error_message = f"Failed retrieving secret keys {project}/{provider}" result = self.api_call( - "GET", path, error_message, params=params, headers=headers, + "GET", + path, + error_message, + params=params, + headers=headers, ) return schemas.SecretKeysData(**result.json()) @@ -2262,7 +2287,7 @@ def delete_project_secrets( ] = schemas.SecretProviderName.kubernetes, secrets: List[str] = None, ): - """ Delete project-context secrets from Kubernetes. + """Delete project-context secrets from Kubernetes. :param project: The project name. :param provider: The name of the secrets-provider to work with. Currently only ``kubernetes`` is supported. @@ -2276,7 +2301,10 @@ def delete_project_secrets( params = {"provider": provider, "secret": secrets} error_message = f"Failed deleting secrets {project}/{provider}" self.api_call( - "DELETE", path, error_message, params=params, + "DELETE", + path, + error_message, + params=params, ) def create_user_secrets( @@ -2287,7 +2315,7 @@ def create_user_secrets( ] = schemas.SecretProviderName.vault, secrets: dict = None, ): - """ Create user-context secret in Vault. Please refer to :py:func:`create_project_secrets` for more details + """Create user-context secret in Vault. Please refer to :py:func:`create_project_secrets` for more details and status of this functionality. 
Note: @@ -2302,12 +2330,17 @@ def create_user_secrets( provider = provider.value path = "user-secrets" secrets_creation_request = schemas.UserSecretCreationRequest( - user=user, provider=provider, secrets=secrets, + user=user, + provider=provider, + secrets=secrets, ) body = secrets_creation_request.dict() error_message = f"Failed creating user secrets - {user}" self.api_call( - "POST", path, error_message, body=dict_to_json(body), + "POST", + path, + error_message, + body=dict_to_json(body), ) @staticmethod @@ -2378,7 +2411,10 @@ def create_or_patch_model_endpoint( ) def delete_model_endpoint_record( - self, project: str, endpoint_id: str, access_key: Optional[str] = None, + self, + project: str, + endpoint_id: str, + access_key: Optional[str] = None, ): """ Deletes the KV record of a given model endpoint, project and endpoint_id are used for lookup @@ -2396,7 +2432,9 @@ def delete_model_endpoint_record( path = f"projects/{project}/model-endpoints/{endpoint_id}" self.api_call( - method="DELETE", path=path, headers={"X-V3io-Access-Key": access_key}, + method="DELETE", + path=path, + headers={"X-V3io-Access-Key": access_key}, ) def list_model_endpoints( @@ -2679,7 +2717,7 @@ def get_marketplace_item( def verify_authorization( self, authorization_verification_input: schemas.AuthorizationVerificationInput ): - """ Verifies authorization for the provided action on the provided resource. + """Verifies authorization for the provided action on the provided resource. :param authorization_verification_input: Instance of :py:class:`~mlrun.api.schemas.AuthorizationVerificationInput` that includes all the needed parameters for diff --git a/mlrun/db/sqldb.py b/mlrun/db/sqldb.py index 975d2b1ff1..a74a3ea336 100644 --- a/mlrun/db/sqldb.py +++ b/mlrun/db/sqldb.py @@ -35,7 +35,9 @@ class SQLDB(RunDBInterface): def __init__( - self, dsn, session=None, + self, + dsn, + session=None, ): self.session = session self.dsn = dsn @@ -51,28 +53,47 @@ def store_log(self, uid, project="", body=b"", append=False): import mlrun.api.crud return self._transform_db_error( - mlrun.api.crud.Logs().store_log, body, project, uid, append, + mlrun.api.crud.Logs().store_log, + body, + project, + uid, + append, ) def get_log(self, uid, project="", offset=0, size=0): import mlrun.api.crud return self._transform_db_error( - mlrun.api.crud.Logs().get_logs, self.session, project, uid, size, offset, + mlrun.api.crud.Logs().get_logs, + self.session, + project, + uid, + size, + offset, ) def store_run(self, struct, uid, project="", iter=0): import mlrun.api.crud return self._transform_db_error( - mlrun.api.crud.Runs().store_run, self.session, struct, uid, iter, project, + mlrun.api.crud.Runs().store_run, + self.session, + struct, + uid, + iter, + project, ) def update_run(self, updates: dict, uid, project="", iter=0): import mlrun.api.crud return self._transform_db_error( - mlrun.api.crud.Runs().update_run, self.session, project, uid, iter, updates, + mlrun.api.crud.Runs().update_run, + self.session, + project, + uid, + iter, + updates, ) def abort_run(self, uid, project="", iter=0): @@ -82,7 +103,11 @@ def read_run(self, uid, project=None, iter=None): import mlrun.api.crud return self._transform_db_error( - mlrun.api.crud.Runs().get_run, self.session, uid, iter, project, + mlrun.api.crud.Runs().get_run, + self.session, + uid, + iter, + project, ) def list_runs( @@ -131,7 +156,11 @@ def del_run(self, uid, project=None, iter=None): import mlrun.api.crud return self._transform_db_error( - mlrun.api.crud.Runs().delete_run, self.session, 
uid, iter, project, + mlrun.api.crud.Runs().delete_run, + self.session, + uid, + iter, + project, ) def del_runs(self, name=None, project=None, labels=None, state=None, days_ago=0): @@ -210,7 +239,11 @@ def del_artifact(self, key, tag="", project=""): import mlrun.api.crud return self._transform_db_error( - mlrun.api.crud.Artifacts().delete_artifact, self.session, key, tag, project, + mlrun.api.crud.Artifacts().delete_artifact, + self.session, + key, + tag, + project, ) def del_artifacts(self, name="", project="", tag="", labels=None): @@ -254,7 +287,10 @@ def delete_function(self, name: str, project: str = ""): import mlrun.api.crud return self._transform_db_error( - mlrun.api.crud.Functions().delete_function, self.session, project, name, + mlrun.api.crud.Functions().delete_function, + self.session, + project, + name, ) def list_functions(self, name=None, project=None, tag=None, labels=None): @@ -281,7 +317,9 @@ def list_schedules(self): return self._transform_db_error(self.db.list_schedules, self.session) def store_project( - self, name: str, project: mlrun.api.schemas.Project, + self, + name: str, + project: mlrun.api.schemas.Project, ) -> mlrun.api.schemas.Project: raise NotImplementedError() @@ -294,7 +332,8 @@ def patch_project( raise NotImplementedError() def create_project( - self, project: mlrun.api.schemas.Project, + self, + project: mlrun.api.schemas.Project, ) -> mlrun.api.schemas.Project: raise NotImplementedError() @@ -373,7 +412,11 @@ def list_features( ) def list_entities( - self, project: str, name: str = None, tag: str = None, labels: List[str] = None, + self, + project: str, + name: str = None, + tag: str = None, + labels: List[str] = None, ): import mlrun.api.crud @@ -527,7 +570,13 @@ def list_feature_vectors( ) def store_feature_vector( - self, feature_vector, name=None, project="", tag=None, uid=None, versioned=True, + self, + feature_vector, + name=None, + project="", + tag=None, + uid=None, + versioned=True, ): import mlrun.api.crud diff --git a/mlrun/execution.py b/mlrun/execution.py index 90e1ce7400..f3dfa27e96 100644 --- a/mlrun/execution.py +++ b/mlrun/execution.py @@ -321,7 +321,7 @@ def tag(self): @property def iteration(self): - """child iteration index, for hyper parameters """ + """child iteration index, for hyper parameters""" return self._iteration @property diff --git a/mlrun/feature_store/api.py b/mlrun/feature_store/api.py index 10a2d3a25c..ca1cf36162 100644 --- a/mlrun/feature_store/api.py +++ b/mlrun/feature_store/api.py @@ -467,7 +467,10 @@ def ingest( ) if schema_options: preview( - featureset, source, options=schema_options, namespace=namespace, + featureset, + source, + options=schema_options, + namespace=namespace, ) infer_stats = InferOptions.get_common_options( infer_options, InferOptions.all_stats() @@ -477,7 +480,11 @@ def ingest( targets = targets or featureset.spec.targets or get_default_targets() df = init_featureset_graph( - source, featureset, namespace, targets=targets, return_df=return_df, + source, + featureset, + namespace, + targets=targets, + return_df=return_df, ) if not InferOptions.get_common_options( infer_stats, InferOptions.Index @@ -869,7 +876,7 @@ def get_feature_vector(uri, project=None): def delete_feature_set(name, project="", tag=None, uid=None, force=False): - """ Delete a :py:class:`~mlrun.feature_store.FeatureSet` object from the DB. + """Delete a :py:class:`~mlrun.feature_store.FeatureSet` object from the DB. 
:param name: Name of the object to delete :param project: Name of the object's project :param tag: Specific object's version tag @@ -891,7 +898,7 @@ def delete_feature_set(name, project="", tag=None, uid=None, force=False): def delete_feature_vector(name, project="", tag=None, uid=None): - """ Delete a :py:class:`~mlrun.feature_store.FeatureVector` object from the DB. + """Delete a :py:class:`~mlrun.feature_store.FeatureVector` object from the DB. :param name: Name of the object to delete :param project: Name of the object's project :param tag: Specific object's version tag diff --git a/mlrun/feature_store/common.py b/mlrun/feature_store/common.py index 4377606cf1..9c189f3cce 100644 --- a/mlrun/feature_store/common.py +++ b/mlrun/feature_store/common.py @@ -85,8 +85,10 @@ def get_feature_set_by_uri(uri, project=None): """get feature set object from db by uri""" db = mlrun.get_run_db() project, name, tag, uid = parse_feature_set_uri(uri, project) - resource = mlrun.api.schemas.AuthorizationResourceTypes.feature_set.to_resource_string( - project, "feature-set" + resource = ( + mlrun.api.schemas.AuthorizationResourceTypes.feature_set.to_resource_string( + project, "feature-set" + ) ) auth_input = AuthorizationVerificationInput( @@ -113,8 +115,10 @@ def get_feature_vector_by_uri(uri, project=None, update=True): project, name, tag, uid = parse_versioned_object_uri(uri, default_project) - resource = mlrun.api.schemas.AuthorizationResourceTypes.feature_vector.to_resource_string( - project, "feature-vector" + resource = ( + mlrun.api.schemas.AuthorizationResourceTypes.feature_vector.to_resource_string( + project, "feature-vector" + ) ) if update: @@ -136,8 +140,10 @@ def verify_feature_set_permissions( ): project, _, _, _ = parse_feature_set_uri(feature_set.uri) - resource = mlrun.api.schemas.AuthorizationResourceTypes.feature_set.to_resource_string( - project, "feature-set" + resource = ( + mlrun.api.schemas.AuthorizationResourceTypes.feature_set.to_resource_string( + project, "feature-set" + ) ) db = feature_set._get_run_db() @@ -162,8 +168,10 @@ def verify_feature_vector_permissions( ): project = feature_vector._metadata.project or mlconf.default_project - resource = mlrun.api.schemas.AuthorizationResourceTypes.feature_vector.to_resource_string( - project, "feature-vector" + resource = ( + mlrun.api.schemas.AuthorizationResourceTypes.feature_vector.to_resource_string( + project, "feature-vector" + ) ) db = mlrun.get_run_db() diff --git a/mlrun/feature_store/feature_set.py b/mlrun/feature_store/feature_set.py index ed364c6cb6..1e3e26e452 100644 --- a/mlrun/feature_store/feature_set.py +++ b/mlrun/feature_store/feature_set.py @@ -314,7 +314,8 @@ def fullname(self): return fullname def _override_run_db( - self, session, + self, + session, ): # Import here, since this method only runs in API context. If this import was global, client would need # API requirements and would fail. 
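The ``delete_feature_set`` and ``delete_feature_vector`` helpers from ``mlrun/feature_store/api.py`` shown above remove the corresponding objects from the MLRun DB by name. A short sketch, assuming they are re-exported from ``mlrun.feature_store`` like the rest of the package API and using illustrative object names::

    import mlrun.feature_store as fstore

    # Delete a feature set and a feature vector from the "iris" project.
    # force is left at its default (False), as in the signature above.
    fstore.delete_feature_set("stocks", project="iris")
    fstore.delete_feature_vector("stocks-vec", project="iris")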
@@ -382,9 +383,9 @@ def set_targets( self.spec.graph.final_step = default_final_step def purge_targets(self, target_names: List[str] = None, silent: bool = False): - """ Delete data of specific targets + """Delete data of specific targets :param target_names: List of names of targets to delete (default: delete all ingested targets) - :param silent: Fail silently if target doesn't exist in featureset status """ + :param silent: Fail silently if target doesn't exist in featureset status""" verify_feature_set_permissions( self, mlrun.api.schemas.AuthorizationAction.delete @@ -752,7 +753,8 @@ def do(self, event): aggs.append(agg) window_column = funcs.window(time_column, spark_window, spark_period) df = input_df.groupBy( - *self.key_columns, window_column.end.alias(time_column), + *self.key_columns, + window_column.end.alias(time_column), ).agg(*aggs) df = df.withColumn(f"{time_column}_window", funcs.lit(window)) dfs.append(df) diff --git a/mlrun/feature_store/ingestion.py b/mlrun/feature_store/ingestion.py index 9654c3c63c..830eedd286 100644 --- a/mlrun/feature_store/ingestion.py +++ b/mlrun/feature_store/ingestion.py @@ -125,7 +125,11 @@ def featureset_initializer(server): featureset, source, targets, _, _ = context_to_ingestion_params(context) graph = featureset.spec.graph.copy() _add_data_steps( - graph, cache, featureset, targets=targets, source=source, + graph, + cache, + featureset, + targets=targets, + source=source, ) featureset.save() server.graph = graph diff --git a/mlrun/feature_store/retrieval/base.py b/mlrun/feature_store/retrieval/base.py index 28c38ae82b..fcb14db87f 100644 --- a/mlrun/feature_store/retrieval/base.py +++ b/mlrun/feature_store/retrieval/base.py @@ -144,20 +144,31 @@ def merge( merge_func = self._join merged_df = merge_func( - merged_df, entity_timestamp_column, featureset, featureset_df, + merged_df, + entity_timestamp_column, + featureset, + featureset_df, ) self._result_df = merged_df @abc.abstractmethod def _asof_join( - self, entity_df, entity_timestamp_column: str, featureset, featureset_df, + self, + entity_df, + entity_timestamp_column: str, + featureset, + featureset_df, ): raise NotImplementedError("_asof_join() operation not implemented in class") @abc.abstractmethod def _join( - self, entity_df, entity_timestamp_column: str, featureset, featureset_df, + self, + entity_df, + entity_timestamp_column: str, + featureset, + featureset_df, ): raise NotImplementedError("_join() operation not implemented in class") diff --git a/mlrun/feature_store/retrieval/dask_merger.py b/mlrun/feature_store/retrieval/dask_merger.py index d4f66bb051..e667f97cfc 100644 --- a/mlrun/feature_store/retrieval/dask_merger.py +++ b/mlrun/feature_store/retrieval/dask_merger.py @@ -99,7 +99,11 @@ def _asof_join( ) merged_df = merge_asof( - entity_df, featureset_df, left_index=True, right_index=True, by=indexes, + entity_df, + featureset_df, + left_index=True, + right_index=True, + by=indexes, ) return merged_df diff --git a/mlrun/feature_store/retrieval/local_merger.py b/mlrun/feature_store/retrieval/local_merger.py index 169f1eb27e..3c59b5bad7 100644 --- a/mlrun/feature_store/retrieval/local_merger.py +++ b/mlrun/feature_store/retrieval/local_merger.py @@ -52,7 +52,8 @@ def _generate_vector( ) else: df = feature_set.to_dataframe( - columns=column_names, time_column=entity_timestamp_column, + columns=column_names, + time_column=entity_timestamp_column, ) # rename columns with aliases df.rename( diff --git a/mlrun/feature_store/retrieval/online.py 
b/mlrun/feature_store/retrieval/online.py index 620715bead..5a15edf639 100644 --- a/mlrun/feature_store/retrieval/online.py +++ b/mlrun/feature_store/retrieval/online.py @@ -20,7 +20,10 @@ def _build_feature_vector_graph( - vector, feature_set_fields, feature_set_objects, fixed_window_type, + vector, + feature_set_fields, + feature_set_objects, + fixed_window_type, ): graph = vector.spec.graph.copy() start_states, default_final_state, responders = graph.check_and_process_graph( diff --git a/mlrun/feature_store/retrieval/spark_merger.py b/mlrun/feature_store/retrieval/spark_merger.py index 1135ad3183..c7bd48f842 100644 --- a/mlrun/feature_store/retrieval/spark_merger.py +++ b/mlrun/feature_store/retrieval/spark_merger.py @@ -99,7 +99,11 @@ def _generate_vector( return OfflineVectorResponse(self) def _asof_join( - self, entity_df, entity_timestamp_column: str, featureset, featureset_df, + self, + entity_df, + entity_timestamp_column: str, + featureset, + featureset_df, ): """Perform an as of join between entity and featureset. @@ -166,7 +170,11 @@ def _asof_join( return filter_most_recent_feature_timestamp.drop("_row_nr", "_rank") def _join( - self, entity_df, entity_timestamp_column: str, featureset, featureset_df, + self, + entity_df, + entity_timestamp_column: str, + featureset, + featureset_df, ): """ diff --git a/mlrun/frameworks/_common/mlrun_interface.py b/mlrun/frameworks/_common/mlrun_interface.py index c554ce8899..93737d491c 100644 --- a/mlrun/frameworks/_common/mlrun_interface.py +++ b/mlrun/frameworks/_common/mlrun_interface.py @@ -68,7 +68,8 @@ def add_interface( # Add the MLRun properties: cls._insert_properties( - obj=obj, properties=restoration_information[0], + obj=obj, + properties=restoration_information[0], ) # Replace the object's properties in MLRun's properties: @@ -158,7 +159,9 @@ def is_applied(cls, obj: MLRunInterfaceableType) -> bool: @classmethod def _insert_properties( - cls, obj: MLRunInterfaceableType, properties: Dict[str, Any] = None, + cls, + obj: MLRunInterfaceableType, + properties: Dict[str, Any] = None, ): """ Insert the properties of the interface to the object. The properties default values are being copied (not deep diff --git a/mlrun/frameworks/_common/model_handler.py b/mlrun/frameworks/_common/model_handler.py index e547aa6a9d..64cd58ac78 100644 --- a/mlrun/frameworks/_common/model_handler.py +++ b/mlrun/frameworks/_common/model_handler.py @@ -371,7 +371,9 @@ def set_parameters( self._parameters.pop(label) def set_extra_data( - self, to_add: Dict[str, ExtraDataType] = None, to_remove: List[str] = None, + self, + to_add: Dict[str, ExtraDataType] = None, + to_remove: List[str] = None, ): """ Update the extra data dictionary of this model artifact. 
@@ -1038,7 +1040,8 @@ def _log_custom_objects(self) -> Dict[str, Artifact]: return artifacts def _read_io_samples( - self, samples: Union[IOSampleType, List[IOSampleType]], + self, + samples: Union[IOSampleType, List[IOSampleType]], ) -> List[Feature]: """ Read the given inputs / output sample to / from the model into a list of MLRun Features (ports) to log in diff --git a/mlrun/frameworks/_dl_common/loggers/mlrun_logger.py b/mlrun/frameworks/_dl_common/loggers/mlrun_logger.py index 5e27f5a52d..e2cddcb447 100644 --- a/mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +++ b/mlrun/frameworks/_dl_common/loggers/mlrun_logger.py @@ -39,7 +39,8 @@ class _Loops: EVALUATION = "evaluation" def __init__( - self, context: mlrun.MLClientCtx, + self, + context: mlrun.MLClientCtx, ): """ Initialize the MLRun logging interface to work with the given context. @@ -53,7 +54,8 @@ def __init__( self._artifacts = {} # type: Dict[str, Artifact] def log_epoch_to_context( - self, epoch: int, + self, + epoch: int, ): """ Log the last epoch. The last epoch information recorded in the given tracking dictionaries will be logged, @@ -164,7 +166,8 @@ def log_run( ) # Log the artifact: self._context.log_artifact( - artifact, local_path=artifact.key, + artifact, + local_path=artifact.key, ) # Collect it for later adding it to the model logging as extra data: self._artifacts[artifact.key.split(".")[0]] = artifact @@ -177,7 +180,8 @@ def log_run( ) # Log the artifact: self._context.log_artifact( - artifact, local_path=artifact.key, + artifact, + local_path=artifact.key, ) # Collect it for later adding it to the model logging as extra data: self._artifacts[artifact.key.split(".")[0]] = artifact @@ -301,7 +305,9 @@ def _generate_summary_results_artifact( # Add titles: summary_figure.update_layout( - title=f"{name} Summary", xaxis_title="Epochs", yaxis_title="Results", + title=f"{name} Summary", + xaxis_title="Epochs", + yaxis_title="Results", ) # Draw the results: @@ -348,7 +354,9 @@ def _generate_dynamic_hyperparameter_values_artifact( # Add titles: hyperparameter_figure.update_layout( - title=name, xaxis_title="Epochs", yaxis_title="Values", + title=name, + xaxis_title="Epochs", + yaxis_title="Values", ) # Draw the values: diff --git a/mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py b/mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py index c0f274dc4f..fbe33be0f2 100644 --- a/mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +++ b/mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py @@ -433,8 +433,10 @@ def _create_output_path(self): # If the tensorboard directory is not provided, set it to the default: if self._tensorboard_directory is None: # Use the default tensorboard logs directory: - self._tensorboard_directory = mlrun.mlconf.default_tensorboard_logs_path.replace( - "{{project}}", self._context.project + self._tensorboard_directory = ( + mlrun.mlconf.default_tensorboard_logs_path.replace( + "{{project}}", self._context.project + ) ) # Try to create the directory, if not succeeded (writing error) change to the artifacts path: try: @@ -617,12 +619,14 @@ def _generate_context_link( :return: The generated link. 
""" - return '{}'.format( - config.resolve_ui_url(), - config.ui.projects_prefix, - context.project, - context.uid, - link_text, + return ( + '{}'.format( + config.resolve_ui_url(), + config.ui.projects_prefix, + context.project, + context.uid, + link_text, + ) ) @staticmethod diff --git a/mlrun/frameworks/_ml_common/logger.py b/mlrun/frameworks/_ml_common/logger.py index 62127cb014..b15e39695c 100644 --- a/mlrun/frameworks/_ml_common/logger.py +++ b/mlrun/frameworks/_ml_common/logger.py @@ -184,7 +184,8 @@ def log_results( self._context.commit(completed=False) def log_run( - self, model_handler: MLModelHandler, + self, + model_handler: MLModelHandler, ): """ End the logger's run, logging the collected artifacts and metrics results with the model. The model will be @@ -196,11 +197,13 @@ def log_run( # model artifact: if self._mode == LoggerMode.TRAINING: model_handler.log( - metrics=self._logged_results, artifacts=self._logged_artifacts, + metrics=self._logged_results, + artifacts=self._logged_artifacts, ) else: model_handler.update( - metrics=self._logged_results, artifacts=self._logged_artifacts, + metrics=self._logged_results, + artifacts=self._logged_artifacts, ) # Commit: diff --git a/mlrun/frameworks/_ml_common/metrics_library.py b/mlrun/frameworks/_ml_common/metrics_library.py index f3df94abe3..583f3a76fc 100644 --- a/mlrun/frameworks/_ml_common/metrics_library.py +++ b/mlrun/frameworks/_ml_common/metrics_library.py @@ -161,7 +161,10 @@ def default(cls, model: ModelType, y: DatasetType = None, **kwargs) -> List[Metr return metrics @staticmethod - def _to_metric_class(metric_entry: MetricEntry, metric_name: str = None,) -> Metric: + def _to_metric_class( + metric_entry: MetricEntry, + metric_name: str = None, + ) -> Metric: """ Create a Metric instance from a user given metric entry. 
diff --git a/mlrun/frameworks/_ml_common/mlrun_interface.py b/mlrun/frameworks/_ml_common/mlrun_interface.py
index a21e12a2e5..17ec18feb0 100644
--- a/mlrun/frameworks/_ml_common/mlrun_interface.py
+++ b/mlrun/frameworks/_ml_common/mlrun_interface.py
@@ -23,12 +23,12 @@ class MLMLRunInterface(MLRunInterface, ABC):
     _PROPERTIES = {
         # A model handler instance with the model for logging / updating the model (if not provided the model won't be
         # logged / updated at the end of training / testing):
-        "_model_handler": None,  # type: MLModelHandler
+        "_model_handler": None,  # > type: MLModelHandler
         # The logger that is logging this model's training / evaluation:
-        "_logger": None,  # type: Logger
+        "_logger": None,  # > type: Logger
         # The test set (For validation post training or evaluation post prediction):
-        "_x_test": None,  # type: DatasetType
-        "_y_test": None,  # type: DatasetType
+        "_x_test": None,  # > type: DatasetType
+        "_y_test": None,  # > type: DatasetType
     }
     _METHODS = [
         "set_model_handler",
@@ -44,7 +44,9 @@ class MLMLRunInterface(MLRunInterface, ABC):

     @classmethod
     def add_interface(
-        cls, obj: ModelType, restoration_information: RestorationInformation = None,
+        cls,
+        obj: ModelType,
+        restoration_information: RestorationInformation = None,
     ):
         """
         Enrich the object with this interface properties, methods and functions so it will have this framework MLRun's
@@ -218,7 +220,10 @@ def _post_fit(self, x: DatasetType, y: DatasetType = None):
         )
         y_pred = self.predict(self._x_test)
         self._post_predict(
-            x=self._x_test, y=self._y_test, y_pred=y_pred, is_predict_proba=False,
+            x=self._x_test,
+            y=self._y_test,
+            y_pred=y_pred,
+            is_predict_proba=False,
         )

         # Log the model with the given attributes:
diff --git a/mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py b/mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py
index b57b394a91..1c5e528307 100644
--- a/mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py
+++ b/mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py
@@ -19,7 +19,10 @@ class CalibrationCurvePlan(MLPlotPlan):
     _ARTIFACT_NAME = "calibration-curve"

     def __init__(
-        self, normalize: bool = False, n_bins: int = 5, strategy: str = "uniform",
+        self,
+        normalize: bool = False,
+        n_bins: int = 5,
+        strategy: str = "uniform",
     ):
         """
         Initialize a calibration curve plan with the given configuration.
@@ -120,7 +123,8 @@ def produce(

         # Creating the artifact:
         self._artifacts[self._ARTIFACT_NAME] = PlotlyArtifact(
-            key=self._ARTIFACT_NAME, figure=fig,
+            key=self._ARTIFACT_NAME,
+            figure=fig,
         )

         return self._artifacts
diff --git a/mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py b/mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py
index 308b342660..5796f7ce67 100644
--- a/mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py
+++ b/mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py
@@ -109,7 +109,9 @@ def produce(
         )

         # Add title:
-        figure.update_layout(title_text="Confusion matrix",)
+        figure.update_layout(
+            title_text="Confusion matrix",
+        )

         # Add custom x-axis title:
         figure.add_annotation(
@@ -149,7 +151,8 @@ def produce(

         # Create the plot's artifact:
         self._artifacts[self._ARTIFACT_NAME] = PlotlyArtifact(
-            key=self._ARTIFACT_NAME, figure=figure,
+            key=self._ARTIFACT_NAME,
+            figure=figure,
         )

         return self._artifacts
diff --git a/mlrun/frameworks/_ml_common/plans/feature_importance_plan.py b/mlrun/frameworks/_ml_common/plans/feature_importance_plan.py
index cec16b3cc3..bcd75682e4 100644
--- a/mlrun/frameworks/_ml_common/plans/feature_importance_plan.py
+++ b/mlrun/frameworks/_ml_common/plans/feature_importance_plan.py
@@ -80,7 +80,8 @@ def produce(

         # Creating the artifact:
         self._artifacts[self._ARTIFACT_NAME] = PlotlyArtifact(
-            key=self._ARTIFACT_NAME, figure=fig,
+            key=self._ARTIFACT_NAME,
+            figure=fig,
         )

         return self._artifacts
diff --git a/mlrun/frameworks/_ml_common/plans/roc_curve_plan.py b/mlrun/frameworks/_ml_common/plans/roc_curve_plan.py
index bee2c273d4..4a27437f92 100644
--- a/mlrun/frameworks/_ml_common/plans/roc_curve_plan.py
+++ b/mlrun/frameworks/_ml_common/plans/roc_curve_plan.py
@@ -145,7 +145,8 @@ def produce(

         # Creating the plot artifact:
         self._artifacts[self._ARTIFACT_NAME] = PlotlyArtifact(
-            key=self._ARTIFACT_NAME, figure=fig,
+            key=self._ARTIFACT_NAME,
+            figure=fig,
         )

         return self._artifacts
diff --git a/mlrun/frameworks/onnx/mlrun_interface.py b/mlrun/frameworks/onnx/mlrun_interface.py
index fab302cb1a..1b16951297 100644
--- a/mlrun/frameworks/onnx/mlrun_interface.py
+++ b/mlrun/frameworks/onnx/mlrun_interface.py
@@ -43,7 +43,8 @@ def __init__(

         # initialize the onnx run time session:
         self._inference_session = onnxruntime.InferenceSession(
-            onnx._serialize(model), providers=self._execution_providers,
+            onnx._serialize(model),
+            providers=self._execution_providers,
         )

         # Get the input layers names:
diff --git a/mlrun/frameworks/parallel_coordinates.py b/mlrun/frameworks/parallel_coordinates.py
index b398fd1562..b0c329fcc7 100644
--- a/mlrun/frameworks/parallel_coordinates.py
+++ b/mlrun/frameworks/parallel_coordinates.py
@@ -25,7 +25,11 @@ def gen_bool_list(col):
         return [name == col for name in output_cols]

     buttons = [
-        dict(label=col, method="update", args=[{"visible": gen_bool_list(col)}],)
+        dict(
+            label=col,
+            method="update",
+            args=[{"visible": gen_bool_list(col)}],
+        )
         for col in output_cols
     ]

diff --git a/mlrun/frameworks/pytorch/callbacks/logging_callback.py b/mlrun/frameworks/pytorch/callbacks/logging_callback.py
index ad73242750..61ed2765fa 100644
--- a/mlrun/frameworks/pytorch/callbacks/logging_callback.py
+++ b/mlrun/frameworks/pytorch/callbacks/logging_callback.py
@@ -237,7 +237,10 @@ def on_epoch_end(self, epoch: int):

         if self._dynamic_hyperparameters_keys:
             for (
                 parameter_name,
-                (source, key_chain,),
+                (
+                    source,
+                    key_chain,
+                ),
             ) in self._dynamic_hyperparameters_keys.items():
                 self._logger.log_dynamic_hyperparameter(
                     parameter_name=parameter_name,
@@ -265,7 +268,9 @@ def on_train_end(self):

         # Store the last training metrics results of this epoch:
         for metric_function in self._objects[self._ObjectKeys.METRIC_FUNCTIONS]:
-            metric_name = self._get_metric_name(metric_function=metric_function,)
+            metric_name = self._get_metric_name(
+                metric_function=metric_function,
+            )
             self._logger.log_training_summary(
                 metric_name=metric_name,
                 result=float(self._logger.training_results[metric_name][-1][-1]),
@@ -303,7 +308,9 @@ def on_validation_end(
             self._objects[self._ObjectKeys.METRIC_FUNCTIONS], metric_values
         ):
             self._logger.log_validation_summary(
-                metric_name=self._get_metric_name(metric_function=metric_function,),
+                metric_name=self._get_metric_name(
+                    metric_function=metric_function,
+                ),
                 result=float(metric_value),
             )

@@ -367,7 +374,9 @@ def on_train_metrics_end(self, metric_values: List[MetricValueType]):
             self._objects[self._ObjectKeys.METRIC_FUNCTIONS], metric_values
         ):
             self._logger.log_training_result(
-                metric_name=self._get_metric_name(metric_function=metric_function,),
+                metric_name=self._get_metric_name(
+                    metric_function=metric_function,
+                ),
                 result=float(metric_value),
             )

@@ -381,7 +390,9 @@ def on_validation_metrics_end(self, metric_values: List[MetricValueType]):
             self._objects[self._ObjectKeys.METRIC_FUNCTIONS], metric_values
         ):
             self._logger.log_validation_result(
-                metric_name=self._get_metric_name(metric_function=metric_function,),
+                metric_name=self._get_metric_name(
+                    metric_function=metric_function,
+                ),
                 result=float(metric_value),
             )

diff --git a/mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py b/mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py
index 508b3780b4..c84e13e3e4 100644
--- a/mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py
+++ b/mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py
@@ -109,7 +109,8 @@ def write_model_to_tensorboard(self, model: Module, input_sample: Tensor):
         :param input_sample: An input sample for writing the model.
         """
         self._summary_writer.add_graph(
-            model=model, input_to_model=input_sample,
+            model=model,
+            input_to_model=input_sample,
         )

     def write_parameters_table_to_tensorboard(self):
@@ -159,7 +160,9 @@ def _write_text_to_tensorboard(self, tag: str, text: str, step: int):
         :param step: The iteration / epoch the text belongs to.
         """
         self._summary_writer.add_text(
-            tag=tag, text_string=text, global_step=step,
+            tag=tag,
+            text_string=text,
+            global_step=step,
         )

     def _write_scalar_to_tensorboard(self, name: str, value: float, step: int):
@@ -171,7 +174,9 @@ def _write_scalar_to_tensorboard(self, name: str, value: float, step: int):
         :param step: The iteration / epoch the value belongs to.
         """
         self._summary_writer.add_scalar(
-            tag=name, scalar_value=value, global_step=step,
+            tag=name,
+            scalar_value=value,
+            global_step=step,
         )

     def _write_weight_histogram_to_tensorboard(
@@ -185,7 +190,9 @@ def _write_weight_histogram_to_tensorboard(
         :param step: The iteration / epoch the weight's histogram state belongs to.
""" self._summary_writer.add_histogram( - tag=name, values=weight, global_step=step, + tag=name, + values=weight, + global_step=step, ) def _write_weight_image_to_tensorboard( diff --git a/mlrun/frameworks/pytorch/callbacks_handler.py b/mlrun/frameworks/pytorch/callbacks_handler.py index adab97f748..5e5acee134 100644 --- a/mlrun/frameworks/pytorch/callbacks_handler.py +++ b/mlrun/frameworks/pytorch/callbacks_handler.py @@ -357,7 +357,11 @@ def on_validation_batch_end( y_true=y_true, ) - def on_inference_begin(self, x, callbacks: List[str] = None,) -> bool: + def on_inference_begin( + self, + x, + callbacks: List[str] = None, + ) -> bool: """ Call the 'on_inference_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all callbacks will be called. @@ -374,7 +378,10 @@ def on_inference_begin(self, x, callbacks: List[str] = None,) -> bool: ) def on_inference_end( - self, y_pred: Tensor, y_true: Tensor, callbacks: List[str] = None, + self, + y_pred: Tensor, + y_true: Tensor, + callbacks: List[str] = None, ) -> bool: """ Call the 'on_inference_end' method of every callback in the callbacks list. If the list is 'None' (not given), diff --git a/mlrun/frameworks/pytorch/mlrun_interface.py b/mlrun/frameworks/pytorch/mlrun_interface.py index e6c488c309..f20abad270 100644 --- a/mlrun/frameworks/pytorch/mlrun_interface.py +++ b/mlrun/frameworks/pytorch/mlrun_interface.py @@ -789,7 +789,10 @@ def _validate( # End of batch callbacks: if not self._callbacks_handler.on_validation_batch_end( - batch=batch, x=x, y_pred=y_pred, y_true=y_true, + batch=batch, + x=x, + y_pred=y_pred, + y_true=y_true, ): break diff --git a/mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py b/mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py index d38b3753e1..7bd70cf0ce 100644 --- a/mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +++ b/mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py @@ -164,7 +164,9 @@ def _write_text_to_tensorboard(self, tag: str, text: str, step: int): """ with self._file_writer.as_default(): tf.summary.text( - name=tag, data=text, step=step, + name=tag, + data=text, + step=step, ) def _write_scalar_to_tensorboard(self, name: str, value: float, step: int): @@ -177,7 +179,9 @@ def _write_scalar_to_tensorboard(self, name: str, value: float, step: int): """ with self._file_writer.as_default(): tf.summary.scalar( - name=name, data=value, step=step, + name=name, + data=value, + step=step, ) def _write_weight_histogram_to_tensorboard( @@ -192,7 +196,9 @@ def _write_weight_histogram_to_tensorboard( """ with self._file_writer.as_default(): tf.summary.histogram( - name=name, data=weight, step=step, + name=name, + data=weight, + step=step, ) def _write_weight_image_to_tensorboard( diff --git a/mlrun/frameworks/tf_keras/mlrun_interface.py b/mlrun/frameworks/tf_keras/mlrun_interface.py index 5ff7337670..5f368c93a4 100644 --- a/mlrun/frameworks/tf_keras/mlrun_interface.py +++ b/mlrun/frameworks/tf_keras/mlrun_interface.py @@ -1,8 +1,7 @@ import importlib import os from abc import ABC -from types import ModuleType -from typing import List, Set, Tuple, Union +from typing import List, Tuple, Union import tensorflow as tf from tensorflow import keras @@ -33,11 +32,11 @@ class TFKerasMLRunInterface(MLRunInterface, ABC): # Attributes to be inserted so the MLRun interface will be fully enabled. 
_PROPERTIES = { # Logging callbacks list: - "_logging_callbacks": set(), # type: Set[Callback] + "_logging_callbacks": set(), # > type: Set[Callback] # Variable to hold the horovod module: - "_hvd": None, # type: ModuleType + "_hvd": None, # > type: ModuleType # List of all the callbacks that should only be applied on rank 0 when using horovod: - "_RANK_0_ONLY_CALLBACKS": { # type: Set[str] + "_RANK_0_ONLY_CALLBACKS": { # > type: Set[str] "LoggingCallback", "MLRunLoggingCallback", "TensorboardLoggingCallback", @@ -62,7 +61,9 @@ class TFKerasMLRunInterface(MLRunInterface, ABC): @classmethod def add_interface( - cls, obj: keras.Model, restoration_information: RestorationInformation = None, + cls, + obj: keras.Model, + restoration_information: RestorationInformation = None, ): """ Enrich the object with this interface properties, methods and functions so it will have this framework MLRun's @@ -171,7 +172,8 @@ def mlrun_evaluate(self, *args, **kwargs): # Call the pre evaluate method: (callbacks, steps) = self._pre_evaluate( - callbacks=kwargs["callbacks"], steps=kwargs["steps"], + callbacks=kwargs["callbacks"], + steps=kwargs["steps"], ) # Assign parameters: @@ -343,7 +345,9 @@ def _pre_fit( return callbacks, verbose, steps_per_epoch, validation_steps def _pre_evaluate( - self, callbacks: List[Callback], steps: Union[int, None], + self, + callbacks: List[Callback], + steps: Union[int, None], ) -> Tuple[List[Callback], Union[int, None]]: """ Method to call before calling 'evaluate' to setup the run and inputs for using horovod. diff --git a/mlrun/k8s_utils.py b/mlrun/k8s_utils.py index 64243d845b..27d4ba22bc 100644 --- a/mlrun/k8s_utils.py +++ b/mlrun/k8s_utils.py @@ -296,7 +296,8 @@ def create_project_service_account(self, project, service_account, namespace="") ) try: api_response = self.v1api.create_namespaced_service_account( - namespace, k8s_service_account, + namespace, + k8s_service_account, ) return api_response except ApiException as exc: @@ -514,7 +515,10 @@ def mount_secret(self, name, path="/secret", items=None, sub_path=None): self.add_volume( client.V1Volume( name=name, - secret=client.V1SecretVolumeSource(secret_name=name, items=items,), + secret=client.V1SecretVolumeSource( + secret_name=name, + items=items, + ), ), mount_path=path, sub_path=sub_path, @@ -558,7 +562,7 @@ def _get_spec(self, template=False): def format_labels(labels): - """ Convert a dictionary of labels into a comma separated string """ + """Convert a dictionary of labels into a comma separated string""" if labels: return ",".join([f"{k}={v}" for k, v in labels.items()]) else: diff --git a/mlrun/model.py b/mlrun/model.py index 2546bf25b8..6ac7333393 100644 --- a/mlrun/model.py +++ b/mlrun/model.py @@ -259,7 +259,8 @@ class Credentials(ModelObj): generate_access_key = "$generate" def __init__( - self, access_key=None, + self, + access_key=None, ): self.access_key = access_key @@ -925,8 +926,7 @@ def NewTask( secrets=None, base=None, ): - """Creates a new task - see new_task - """ + """Creates a new task - see new_task""" warnings.warn( "NewTask will be deprecated in 0.7.0, and will be removed in 0.9.0, use new_task instead", # TODO: In 0.7.0 and replace NewTask to new_task in examples & demos @@ -1158,7 +1158,11 @@ class DataTarget(DataTargetBase): ] def __init__( - self, kind: str = None, name: str = "", path=None, online=None, + self, + kind: str = None, + name: str = "", + path=None, + online=None, ): super().__init__(kind, name, path) self.status = "" diff --git 
a/mlrun/model_monitoring/stream_processing_fs.py b/mlrun/model_monitoring/stream_processing_fs.py index 5dfbfe217b..60949cbbff 100644 --- a/mlrun/model_monitoring/stream_processing_fs.py +++ b/mlrun/model_monitoring/stream_processing_fs.py @@ -110,8 +110,10 @@ def __init__( ) self.tsdb_path = f"{self.tsdb_container}/{self.tsdb_path}" - self.parquet_path = config.model_endpoint_monitoring.store_prefixes.user_space.format( - project=project, kind="parquet" + self.parquet_path = ( + config.model_endpoint_monitoring.store_prefixes.user_space.format( + project=project, kind="parquet" + ) ) logger.info( @@ -161,7 +163,10 @@ def create_feature_set(self): step_name="Aggregates", ) feature_set.add_aggregation( - LATENCY, ["avg"], self.aggregate_avg_windows, self.aggregate_avg_period, + LATENCY, + ["avg"], + self.aggregate_avg_windows, + self.aggregate_avg_period, ) feature_set.graph.add_step( "storey.steps.SampleWindow", @@ -417,7 +422,8 @@ def do(self, full_event): versioned_model = f"{model}:{version}" if version else f"{model}:latest" endpoint_id = create_model_endpoint_id( - function_uri=function_uri, versioned_model=versioned_model, + function_uri=function_uri, + versioned_model=versioned_model, ) endpoint_id = str(endpoint_id) @@ -442,23 +448,44 @@ def do(self, full_event): features = event.get("request", {}).get("inputs") predictions = event.get("resp", {}).get("outputs") - if not self.is_valid(endpoint_id, is_not_none, timestamp, ["when"],): + if not self.is_valid( + endpoint_id, + is_not_none, + timestamp, + ["when"], + ): return None if endpoint_id not in self.first_request: self.first_request[endpoint_id] = timestamp self.last_request[endpoint_id] = timestamp - if not self.is_valid(endpoint_id, is_not_none, request_id, ["request", "id"],): + if not self.is_valid( + endpoint_id, + is_not_none, + request_id, + ["request", "id"], + ): return None - if not self.is_valid(endpoint_id, is_not_none, latency, ["microsec"],): + if not self.is_valid( + endpoint_id, + is_not_none, + latency, + ["microsec"], + ): return None if not self.is_valid( - endpoint_id, is_not_none, features, ["request", "inputs"], + endpoint_id, + is_not_none, + features, + ["request", "inputs"], ): return None if not self.is_valid( - endpoint_id, is_not_none, predictions, ["resp", "outputs"], + endpoint_id, + is_not_none, + predictions, + ["resp", "outputs"], ): return None @@ -562,7 +589,8 @@ def enrich_even_details(event) -> Optional[dict]: versioned_model = f"{model}:{version}" if version else f"{model}:latest" endpoint_id = create_model_endpoint_id( - function_uri=function_uri, versioned_model=versioned_model, + function_uri=function_uri, + versioned_model=versioned_model, ) endpoint_id = str(endpoint_id) diff --git a/mlrun/platforms/iguazio.py b/mlrun/platforms/iguazio.py index d5785ae38c..e76feeeb95 100644 --- a/mlrun/platforms/iguazio.py +++ b/mlrun/platforms/iguazio.py @@ -48,7 +48,10 @@ def xcp_op( args = ["-r"] + args return dsl.ContainerOp( - name="xcp", image="yhaviv/invoke", command=["xcp"], arguments=args, + name="xcp", + image="yhaviv/invoke", + command=["xcp"], + arguments=args, ) diff --git a/mlrun/platforms/other.py b/mlrun/platforms/other.py index b8856ac920..631472c1a0 100644 --- a/mlrun/platforms/other.py +++ b/mlrun/platforms/other.py @@ -75,7 +75,9 @@ def auto_mount(pvc_name="", volume_mount_path="", volume_name=None): volume_name=volume_name or "shared-persistency", ) if "MLRUN_PVC_MOUNT" in os.environ: - return mount_pvc(volume_name=volume_name or "shared-persistency",) + return mount_pvc( + 
volume_name=volume_name or "shared-persistency", + ) # In the case of MLRun-kit when working remotely, no env variables will be defined but auto-mount # parameters may still be declared - use them in that case. if config.storage.auto_mount_type == "pvc": diff --git a/mlrun/projects/pipelines.py b/mlrun/projects/pipelines.py index aadfbb9186..125dbf5b55 100644 --- a/mlrun/projects/pipelines.py +++ b/mlrun/projects/pipelines.py @@ -349,7 +349,10 @@ def save(cls, project, workflow_spec: WorkflowSpec, target, artifact_path=None): workflow_file = workflow_spec.get_source_file(project.spec.context) functions = FunctionsDict(project) pipeline = create_pipeline( - project, workflow_file, functions, secrets=project._secrets, + project, + workflow_file, + functions, + secrets=project._secrets, ) artifact_path = artifact_path or project.spec.artifact_path @@ -384,7 +387,10 @@ def run( ttl=workflow_spec.ttl, ) project.notifiers.push_start_message( - project.metadata.name, project.get_param("commit_id", None), id, True, + project.metadata.name, + project.get_param("commit_id", None), + id, + True, ) pipeline_context.clear() return _PipelineRunStatus(id, cls, project=project, workflow=workflow_spec) diff --git a/mlrun/projects/project.py b/mlrun/projects/project.py index c8d6a478de..06b67c4b35 100644 --- a/mlrun/projects/project.py +++ b/mlrun/projects/project.py @@ -1436,7 +1436,7 @@ def get_function( return function def get_function_objects(self) -> typing.Dict[str, mlrun.runtimes.BaseRuntime]: - """"get a virtual dict with all the project functions ready for use in a pipeline""" + """ "get a virtual dict with all the project functions ready for use in a pipeline""" self.sync_functions() return FunctionsDict(self) @@ -1845,7 +1845,7 @@ def export(self, filepath=None): fp.write(self.to_yaml()) def set_model_monitoring_credentials(self, access_key: str): - """ Set the credentials that will be used by the project's model monitoring + """Set the credentials that will be used by the project's model monitoring infrastructure functions. 
The supplied credentials must have data access diff --git a/mlrun/runtimes/base.py b/mlrun/runtimes/base.py index d37ae4035b..c55ca0aeb7 100644 --- a/mlrun/runtimes/base.py +++ b/mlrun/runtimes/base.py @@ -1110,11 +1110,21 @@ def delete_resources( crd_group, crd_version, crd_plural = self._get_crd_info() if crd_group and crd_version and crd_plural: deleted_resources = self._delete_crd_resources( - db, db_session, namespace, label_selector, force, grace_period, + db, + db_session, + namespace, + label_selector, + force, + grace_period, ) else: deleted_resources = self._delete_pod_resources( - db, db_session, namespace, label_selector, force, grace_period, + db, + db_session, + namespace, + label_selector, + force, + grace_period, ) self._delete_resources( db, @@ -1381,7 +1391,10 @@ def _resolve_label_selector( return label_selector def _wait_for_pods_deletion( - self, namespace: str, deleted_pods: List[Dict], label_selector: str = None, + self, + namespace: str, + deleted_pods: List[Dict], + label_selector: str = None, ): k8s_helper = get_k8s_helper() deleted_pod_names = [pod_dict["metadata"]["name"] for pod_dict in deleted_pods] @@ -1415,7 +1428,9 @@ def _verify_pods_removed(): ) def _wait_for_crds_underlying_pods_deletion( - self, deleted_crds: List[Dict], label_selector: str = None, + self, + deleted_crds: List[Dict], + label_selector: str = None, ): # we're using here the run identifier as the common ground to identify which pods are relevant to which CRD, so # if they are not coupled we are not able to wait - simply return @@ -1594,7 +1609,10 @@ def _delete_crd_resources( try: self._pre_deletion_runtime_resource_run_actions( - db, db_session, crd_object, desired_run_state, + db, + db_session, + crd_object, + desired_run_state, ) except Exception as exc: # Don't prevent the deletion for failure in the pre deletion run actions @@ -1648,7 +1666,10 @@ def _pre_deletion_runtime_resource_run_actions( self._ensure_run_logs_collected(db, db_session, project, uid) def _is_runtime_resource_run_in_terminal_state( - self, db: DBInterface, db_session: Session, runtime_resource: Dict, + self, + db: DBInterface, + db_session: Session, + runtime_resource: Dict, ) -> Tuple[bool, Optional[datetime]]: """ A runtime can have different underlying resources (like pods or CRDs) - to generalize we call it runtime @@ -1677,7 +1698,9 @@ def _is_runtime_resource_run_in_terminal_state( return True, last_update def _list_runs_for_monitoring( - self, db: DBInterface, db_session: Session, + self, + db: DBInterface, + db_session: Session, ): runs = db.list_runs(db_session, project="*") project_run_uid_map = {} @@ -1742,16 +1765,27 @@ def _monitor_runtime_resource( return run = project_run_uid_map.get(project, {}).get(uid) if runtime_resource_is_crd: - (_, _, run_state,) = self._resolve_crd_object_status_info( - db, db_session, runtime_resource - ) + ( + _, + _, + run_state, + ) = self._resolve_crd_object_status_info(db, db_session, runtime_resource) else: - (_, _, run_state,) = self._resolve_pod_status_info( - db, db_session, runtime_resource - ) + ( + _, + _, + run_state, + ) = self._resolve_pod_status_info(db, db_session, runtime_resource) self._update_ui_url(db, db_session, project, uid, runtime_resource, run) _, updated_run_state = self._ensure_run_state( - db, db_session, project, uid, name, run_state, run, search_run=False, + db, + db_session, + project, + uid, + name, + run_state, + run, + search_run=False, ) if updated_run_state in RunStates.terminal_states(): self._ensure_run_logs_collected(db, db_session, 
project, uid) @@ -1998,7 +2032,11 @@ def _delete_crd(namespace, crd_group, crd_version, crd_plural, crd_object): name = crd_object["metadata"]["name"] try: k8s_helper.crdapi.delete_namespaced_custom_object( - crd_group, crd_version, namespace, crd_plural, name, + crd_group, + crd_version, + namespace, + crd_plural, + name, ) logger.info( "Deleted crd object", diff --git a/mlrun/runtimes/function.py b/mlrun/runtimes/function.py index 4b11e497c6..8faaa6066a 100644 --- a/mlrun/runtimes/function.py +++ b/mlrun/runtimes/function.py @@ -197,8 +197,8 @@ def generate_nuclio_volumes(self): {"volume": self._volumes[volume_name], "volumeMount": volume_mount} ) - volumes_without_volume_mounts = volume_with_volume_mounts_names.symmetric_difference( - self._volumes.keys() + volumes_without_volume_mounts = ( + volume_with_volume_mounts_names.symmetric_difference(self._volumes.keys()) ) if volumes_without_volume_mounts: raise ValueError( @@ -272,7 +272,11 @@ def add_trigger(self, name, spec): return self def with_source_archive( - self, source, handler="", runtime="", secrets=None, + self, + source, + handler="", + runtime="", + secrets=None, ): """Load nuclio function from remote source :param source: a full path to the nuclio function source (code entry) to load the function from @@ -475,7 +479,9 @@ def add_model(self, name, model_path, **kw): def from_image(self, image): config = nuclio.config.new_config() update_in( - config, "spec.handler", self.spec.function_handler or "main:handler", + config, + "spec.handler", + self.spec.function_handler or "main:handler", ) update_in(config, "spec.image", image) update_in(config, "spec.build.codeEntryType", "image") @@ -625,7 +631,10 @@ def deploy( self.save(versioned=False) self._ensure_run_db() internal_invocation_urls, external_invocation_urls = deploy_nuclio_function( - self, dashboard=dashboard, watch=True, auth_info=auth_info, + self, + dashboard=dashboard, + watch=True, + auth_info=auth_info, ) self.status.internal_invocation_urls = internal_invocation_urls self.status.external_invocation_urls = external_invocation_urls @@ -1009,7 +1018,11 @@ async def _invoke_async(self, tasks, url, headers, secrets, generator): self.store_run(task) task.spec.secret_sources = secrets or [] resp = submit(session, url, task, semaphore, headers=headers) - runs.append(asyncio.ensure_future(resp,)) + runs.append( + asyncio.ensure_future( + resp, + ) + ) for result in asyncio.as_completed(runs): status, resp, logs, task = await result diff --git a/mlrun/runtimes/local.py b/mlrun/runtimes/local.py index 104af37e12..3ae52d8ed4 100644 --- a/mlrun/runtimes/local.py +++ b/mlrun/runtimes/local.py @@ -122,7 +122,11 @@ def remote_handler_wrapper(task, handler, workdir=None): if task and not isinstance(task, dict): task = json.loads(task) - context = MLClientCtx.from_dict(task, autocommit=False, host=socket.gethostname(),) + context = MLClientCtx.from_dict( + task, + autocommit=False, + host=socket.gethostname(), + ) runobj = RunObject.from_dict(task) sout, serr = exec_from_params(handler, runobj, context, workdir) diff --git a/mlrun/runtimes/mpijob/abstract.py b/mlrun/runtimes/mpijob/abstract.py index 0607424762..0474ab4712 100644 --- a/mlrun/runtimes/mpijob/abstract.py +++ b/mlrun/runtimes/mpijob/abstract.py @@ -107,7 +107,10 @@ def spec(self, spec): @abc.abstractmethod def _generate_mpi_job( - self, runobj: RunObject, execution: MLClientCtx, meta: client.V1ObjectMeta, + self, + runobj: RunObject, + execution: MLClientCtx, + meta: client.V1ObjectMeta, ) -> typing.Dict: pass @@ -172,7 
+175,8 @@ def _run(self, runobj: RunObject, execution: MLClientCtx): execution.set_state("completed") else: execution.set_state( - "error", f"MpiJob {meta.name} finished with state {status}", + "error", + f"MpiJob {meta.name} finished with state {status}", ) else: txt = f"MpiJob {meta.name} launcher pod {launcher} state {state}" diff --git a/mlrun/runtimes/mpijob/v1.py b/mlrun/runtimes/mpijob/v1.py index 7db0a4d66e..a53a5f2a75 100644 --- a/mlrun/runtimes/mpijob/v1.py +++ b/mlrun/runtimes/mpijob/v1.py @@ -155,7 +155,10 @@ def _enrich_worker_configurations(self, worker_pod_template): ) def _generate_mpi_job( - self, runobj: RunObject, execution: MLClientCtx, meta: client.V1ObjectMeta, + self, + runobj: RunObject, + execution: MLClientCtx, + meta: client.V1ObjectMeta, ) -> dict: pod_labels = deepcopy(meta.labels) pod_labels["mlrun/job"] = meta.name @@ -185,7 +188,9 @@ def _generate_mpi_job( self._update_container(pod_template, "env", extra_env + self.spec.env) if self.spec.image_pull_policy: self._update_container( - pod_template, "imagePullPolicy", self.spec.image_pull_policy, + pod_template, + "imagePullPolicy", + self.spec.image_pull_policy, ) if self.spec.workdir: self._update_container(pod_template, "workingDir", self.spec.workdir) @@ -226,16 +231,22 @@ def _generate_mpi_job( # update the replicas only for workers update_in( - job, "spec.mpiReplicaSpecs.Worker.replicas", self.spec.replicas or 1, + job, + "spec.mpiReplicaSpecs.Worker.replicas", + self.spec.replicas or 1, ) update_in( - job, "spec.cleanPodPolicy", self.spec.clean_pod_policy, + job, + "spec.cleanPodPolicy", + self.spec.clean_pod_policy, ) if execution.get_param("slots_per_worker"): update_in( - job, "spec.slotsPerWorker", execution.get_param("slots_per_worker"), + job, + "spec.slotsPerWorker", + execution.get_param("slots_per_worker"), ) update_in(job, "metadata", meta.to_dict()) diff --git a/mlrun/runtimes/pod.py b/mlrun/runtimes/pod.py index 8f01f905fb..a7e1ca59b5 100644 --- a/mlrun/runtimes/pod.py +++ b/mlrun/runtimes/pod.py @@ -559,8 +559,10 @@ def _add_vault_params_to_spec(self, runobj=None, project=None): logger.warning("No project provided. 
Cannot add vault parameters") return - service_account_name = mlconf.secret_stores.vault.project_service_account_name.format( - project=project_name + service_account_name = ( + mlconf.secret_stores.vault.project_service_account_name.format( + project=project_name + ) ) project_vault_secret_name = self._get_k8s().get_project_vault_secret_name( diff --git a/mlrun/runtimes/serving.py b/mlrun/runtimes/serving.py index e18861b080..ad33721fb5 100644 --- a/mlrun/runtimes/serving.py +++ b/mlrun/runtimes/serving.py @@ -216,7 +216,12 @@ def spec(self, spec): self._spec = self._verify_dict(spec, "spec", ServingSpec) def set_topology( - self, topology=None, class_name=None, engine=None, exist_ok=False, **class_args, + self, + topology=None, + class_name=None, + engine=None, + exist_ok=False, + **class_args, ) -> Union[RootFlowStep, RouterStep]: """set the serving graph topology (router/flow) and root class or params @@ -610,7 +615,10 @@ def to_mock_server( **kwargs, ) server.init_states( - context=None, namespace=namespace, logger=logger, is_mock=True, + context=None, + namespace=namespace, + logger=logger, + is_mock=True, ) server.init_object(namespace) return server diff --git a/mlrun/runtimes/sparkjob/abstract.py b/mlrun/runtimes/sparkjob/abstract.py index 288286a1c0..10d8053dc3 100644 --- a/mlrun/runtimes/sparkjob/abstract.py +++ b/mlrun/runtimes/sparkjob/abstract.py @@ -343,7 +343,10 @@ def _run(self, runobj: RunObject, execution: MLClientCtx): update_in(job, "spec.driver.labels", pod_labels) update_in(job, "spec.executor.labels", pod_labels) verify_and_update_in( - job, "spec.executor.instances", self.spec.replicas or 1, int, + job, + "spec.executor.instances", + self.spec.replicas or 1, + int, ) if self.spec.node_selector: update_in(job, "spec.nodeSelector", self.spec.node_selector) @@ -430,7 +433,10 @@ def _run(self, runobj: RunObject, execution: MLClientCtx): update_in(job, "spec.executor.gpu.name", gpu_type) if gpu_quantity: verify_and_update_in( - job, "spec.executor.gpu.quantity", gpu_quantity, int, + job, + "spec.executor.gpu.quantity", + gpu_quantity, + int, ) if "limits" in self.spec.driver_resources: if "cpu" in self.spec.driver_resources["limits"]: @@ -456,7 +462,10 @@ def _run(self, runobj: RunObject, execution: MLClientCtx): update_in(job, "spec.driver.gpu.name", gpu_type) if gpu_quantity: verify_and_update_in( - job, "spec.driver.gpu.quantity", gpu_quantity, int, + job, + "spec.driver.gpu.quantity", + gpu_quantity, + int, ) self._enrich_job(job) @@ -475,7 +484,10 @@ def _enrich_job(self, job): raise NotImplementedError() def _submit_job( - self, job, meta, code=None, + self, + job, + meta, + code=None, ): namespace = meta.namespace k8s = self._get_k8s() diff --git a/mlrun/runtimes/utils.py b/mlrun/runtimes/utils.py index c8f8e01fd3..385bfc10cf 100644 --- a/mlrun/runtimes/utils.py +++ b/mlrun/runtimes/utils.py @@ -320,7 +320,11 @@ def generate_function_image_name(project: str, name: str, tag: str) -> str: def fill_function_image_name_template( - registry: str, repository: str, project: str, name: str, tag: str, + registry: str, + repository: str, + project: str, + name: str, + tag: str, ) -> str: image_name_prefix = resolve_function_target_image_name_prefix(project, name) return f"{registry}{repository}/{image_name_prefix}:{tag}" @@ -517,7 +521,8 @@ def enrich_function_from_dict(function, function_dict): function.set_env(env_dict["name"], env_dict["value"]) else: function.set_env( - env_dict["name"], value_from=env_dict["valueFrom"], + env_dict["name"], + 
value_from=env_dict["valueFrom"], ) elif attribute == "volumes": function.spec.update_vols_and_mounts(override_value, []) diff --git a/mlrun/serving/remote.py b/mlrun/serving/remote.py index 32aca31d8f..6c6bf6c30d 100644 --- a/mlrun/serving/remote.py +++ b/mlrun/serving/remote.py @@ -38,8 +38,7 @@ def get_http_adapter(retries, backoff_factor): class RemoteStep(storey.SendToHttp): - """class for calling remote endpoints - """ + """class for calling remote endpoints""" def __init__( self, @@ -240,8 +239,7 @@ def _get_data(self, data, headers): class BatchHttpRequests(_ConcurrentJobExecution): - """class for calling remote endpoints in parallel - """ + """class for calling remote endpoints in parallel""" def __init__( self, diff --git a/mlrun/serving/routers.py b/mlrun/serving/routers.py index 8b26cc92b1..8ea8e1c5d6 100644 --- a/mlrun/serving/routers.py +++ b/mlrun/serving/routers.py @@ -813,7 +813,13 @@ def __init__( :param kwargs: extra arguments """ super().__init__( - context, name, routes, protocol, url_prefix, health_prefix, **kwargs, + context, + name, + routes, + protocol, + url_prefix, + health_prefix, + **kwargs, ) self.feature_vector_uri = feature_vector_uri @@ -824,7 +830,8 @@ def __init__( def post_init(self, mode="sync"): super().post_init(mode) self._feature_service = mlrun.feature_store.get_online_feature_service( - feature_vector=self.feature_vector_uri, impute_policy=self.impute_policy, + feature_vector=self.feature_vector_uri, + impute_policy=self.impute_policy, ) def preprocess(self, event): @@ -966,7 +973,8 @@ def __init__( def post_init(self, mode="sync"): super().post_init(mode) self._feature_service = mlrun.feature_store.get_online_feature_service( - feature_vector=self.feature_vector_uri, impute_policy=self.impute_policy, + feature_vector=self.feature_vector_uri, + impute_policy=self.impute_policy, ) def preprocess(self, event): diff --git a/mlrun/serving/states.py b/mlrun/serving/states.py index 1cc41fb275..3e89c989af 100644 --- a/mlrun/serving/states.py +++ b/mlrun/serving/states.py @@ -1193,7 +1193,10 @@ def _add_graphviz_router(graph, step, source=None, **kwargs): def _add_graphviz_flow( - graph, step, source=None, targets=None, + graph, + step, + source=None, + targets=None, ): start_steps, default_final_step, responders = step.check_and_process_graph( allow_empty=True @@ -1230,7 +1233,13 @@ def _add_graphviz_flow( def _generate_graphviz( - step, renderer, filename=None, format=None, source=None, targets=None, **kw, + step, + renderer, + filename=None, + format=None, + source=None, + targets=None, + **kw, ): try: from graphviz import Digraph @@ -1385,7 +1394,9 @@ def _init_async_objects(context, steps): endpoint, stream_path = parse_v3io_path(step.path) stream_path = stream_path.strip("/") step._async_object = storey.StreamTarget( - storey.V3ioDriver(endpoint), stream_path, context=context, + storey.V3ioDriver(endpoint), + stream_path, + context=context, ) else: step._async_object = storey.Map(lambda x: x) diff --git a/mlrun/serving/v1_serving.py b/mlrun/serving/v1_serving.py index c84d127e6d..b92c50aee6 100644 --- a/mlrun/serving/v1_serving.py +++ b/mlrun/serving/v1_serving.py @@ -174,7 +174,11 @@ def nuclio_serving_handler(context, event): actions = "|".join(context.router.keys()) models = "|".join(context.models.keys()) body = f"Got path: {event.path} \n Path must be / \nactions: {actions} \nmodels: {models}" - return context.Response(body=body, content_type="text/plain", status_code=404,) + return context.Response( + body=body, + content_type="text/plain", + 
status_code=404, + ) return route(context, model_name, event) diff --git a/mlrun/serving/v2_serving.py b/mlrun/serving/v2_serving.py index 5d8cd3fa94..39aa47d95c 100644 --- a/mlrun/serving/v2_serving.py +++ b/mlrun/serving/v2_serving.py @@ -167,31 +167,31 @@ def set_metric(self, name: str, value): def get_model(self, suffix=""): """get the model file(s) and metadata from model store - the method returns a path to the model file and the extra data (dict of dataitem objects) - it also loads the model metadata into the self.model_spec attribute, allowing direct access - to all the model metadata attributes. - - get_model is usually used in the model .load() method to init the model - Examples - -------- - :: - - def load(self): - model_file, extra_data = self.get_model(suffix='.pkl') - self.model = load(open(model_file, "rb")) - categories = extra_data['categories'].as_df() - - Parameters - ---------- - suffix : str - optional, model file suffix (when the model_path is a directory) - - Returns - ------- - str - (local) model file - dict - extra dataitems dictionary + the method returns a path to the model file and the extra data (dict of dataitem objects) + it also loads the model metadata into the self.model_spec attribute, allowing direct access + to all the model metadata attributes. + + get_model is usually used in the model .load() method to init the model + Examples + -------- + :: + + def load(self): + model_file, extra_data = self.get_model(suffix='.pkl') + self.model = load(open(model_file, "rb")) + categories = extra_data['categories'].as_df() + + Parameters + ---------- + suffix : str + optional, model file suffix (when the model_path is a directory) + + Returns + ------- + str + (local) model file + dict + extra dataitems dictionary """ model_file, self.model_spec, extra_dataitems = mlrun.artifacts.get_model( diff --git a/mlrun/utils/helpers.py b/mlrun/utils/helpers.py index 3c91a53b2f..3af4c411b5 100644 --- a/mlrun/utils/helpers.py +++ b/mlrun/utils/helpers.py @@ -810,7 +810,7 @@ def create_exponential_backoff(base=2, max_value=120, scale_factor=1): # This "complex" implementation (unlike the one in linear backoff) is to avoid exponent growing too fast and # risking going behind max_int - next_value = scale_factor * (base ** exponent) + next_value = scale_factor * (base**exponent) if next_value < max_value: exponent += 1 yield next_value diff --git a/mlrun/utils/model_monitoring.py b/mlrun/utils/model_monitoring.py index ae04ee9cea..847c24d3e5 100644 --- a/mlrun/utils/model_monitoring.py +++ b/mlrun/utils/model_monitoring.py @@ -20,7 +20,10 @@ class FunctionURI: def from_string(cls, function_uri): project, uri, tag, hash_key = parse_versioned_object_uri(function_uri) return cls( - project=project, function=uri, tag=tag or None, hash_key=hash_key or None, + project=project, + function=uri, + tag=tag or None, + hash_key=hash_key or None, ) @@ -97,7 +100,7 @@ def parse_model_endpoint_store_prefix(store_prefix: str): def set_project_model_monitoring_credentials( access_key: str, project: Optional[str] = None ): - """ Set the credentials that will be used by the project's model monitoring + """Set the credentials that will be used by the project's model monitoring infrastructure functions. 
The supplied credentials must have data access diff --git a/mlrun/utils/vault.py b/mlrun/utils/vault.py index f9fd8423c4..1a679466e5 100644 --- a/mlrun/utils/vault.py +++ b/mlrun/utils/vault.py @@ -259,8 +259,8 @@ def init_project_vault_configuration(project): namespace = mlconf.namespace k8s = get_k8s_helper(silent=True) - service_account_name = mlconf.secret_stores.vault.project_service_account_name.format( - project=project + service_account_name = ( + mlconf.secret_stores.vault.project_service_account_name.format(project=project) ) secret_name = k8s.get_project_vault_secret_name( diff --git a/pyproject.toml b/pyproject.toml index 51ead94b19..aa7e7666d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,3 @@ -# please keep this to a minimum - defaults are good -[tool.black] -exclude = ''' -/( - \.git - | \.venv - | \venv -)/ -''' - [tool.isort] profile = "black" multi_line_output = 3 diff --git a/requirements.txt b/requirements.txt index 92dbffed52..754e0087ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,20 +4,12 @@ urllib3>=1.25.4, <1.27 chardet>=3.0.2, <4.0 GitPython~=3.0 aiohttp~=3.8 -click~=7.0 -# kfp ~1.0.1 resolves to 1.0.4, which has google-auth>=1.6.1 which resolves to 2.x which is incompatible with -# google-cloud-storage (from kfp) that is >=1.13.0 and resolves to 1.42.0) and has google-api-core that is -# >=1.29.0,<3.0dev and resolves to 1.31.2 which has google-auth >=1.25.0,<2.0dev which is incompatible -google-auth>=1.25.0, <2.0dev +click~=8.0 # 3.0/3.2 iguazio system uses 1.0.1, but we needed >=1.6.0 to be compatible with k8s>=12.0 to fix scurity issue # since the sdk is still mark as beta (and not stable) I'm limiting to only patch changes kfp~=1.8.0 nest-asyncio~=1.0 ipython~=7.0 -# nuclio-jupyter has notebook>=5.2.0 which resolves to 6.4.0 which has ipykernel without specifier, which from 0.6.0 -# has ipython>=7.23.1 which is incompatible with our ipython specifiers, therefore instsalling ipykernel 5.x before -# nuclio-jupyter -ipykernel~=5.0 nuclio-jupyter~=0.8.22 # >=1.16.5 from pandas 1.2.1 and <1.20.0 because we're hitting the same issue as this one # https://github.com/Azure/MachineLearningNotebooks/issues/1314 diff --git a/tests/api/api/feature_store/base.py b/tests/api/api/feature_store/base.py index 461207240f..3f20a84bb4 100644 --- a/tests/api/api/feature_store/base.py +++ b/tests/api/api/feature_store/base.py @@ -75,7 +75,10 @@ def _assert_diff_as_expected_except_for_specific_metadata( for field in allowed_metadata_fields: exclude_paths.append(f"root['metadata']['{field}']") diff = DeepDiff( - expected_object, actual_object, ignore_order=True, exclude_paths=exclude_paths, + expected_object, + actual_object, + ignore_order=True, + exclude_paths=exclude_paths, ) assert diff == expected_diff diff --git a/tests/api/api/feature_store/test_feature_sets.py b/tests/api/api/feature_store/test_feature_sets.py index 2b6ffa25fb..4949937626 100644 --- a/tests/api/api/feature_store/test_feature_sets.py +++ b/tests/api/api/feature_store/test_feature_sets.py @@ -525,7 +525,10 @@ def test_list_feature_sets_tags(db: Session, client: TestClient) -> None: client, project_name, feature_set["metadata"]["name"], tag, feature_set ) _list_tags_and_assert( - client, "feature_sets", project_name, tags, + client, + "feature_sets", + project_name, + tags, ) diff --git a/tests/api/api/feature_store/test_feature_vectors.py b/tests/api/api/feature_store/test_feature_vectors.py index 3d542a0215..80e2857e9a 100644 --- 
a/tests/api/api/feature_store/test_feature_vectors.py +++ b/tests/api/api/feature_store/test_feature_vectors.py @@ -404,7 +404,10 @@ def test_list_feature_vectors_tags(db: Session, client: TestClient) -> None: feature_vector, ) _list_tags_and_assert( - client, "feature_vectors", project_name, tags, + client, + "feature_vectors", + project_name, + tags, ) @@ -458,11 +461,16 @@ def _verify_queried_resources( (project, "some-feature-set"), ] assert ( - deepdiff.DeepDiff(expected_resources, resources, ignore_order=True,) == {} + deepdiff.DeepDiff( + expected_resources, + resources, + ignore_order=True, + ) + == {} ) - mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resources_permissions = unittest.mock.Mock( - side_effect=_verify_queried_resources + mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resources_permissions = ( + unittest.mock.Mock(side_effect=_verify_queried_resources) ) mlrun.api.api.endpoints.feature_store._verify_feature_vector_features_permissions( mlrun.api.schemas.AuthInfo(), diff --git a/tests/api/api/test_background_tasks.py b/tests/api/api/test_background_tasks.py index 2e79baec4b..5ac9d00fc8 100644 --- a/tests/api/api/test_background_tasks.py +++ b/tests/api/api/test_background_tasks.py @@ -21,7 +21,9 @@ response_model=mlrun.api.schemas.BackgroundTask, ) def create_project_background_task( - project: str, background_tasks: fastapi.BackgroundTasks, failed_task: bool = False, + project: str, + background_tasks: fastapi.BackgroundTasks, + failed_task: bool = False, ): function = bump_counter if failed_task: diff --git a/tests/api/api/test_docs.py b/tests/api/api/test_docs.py index 44bc3b50a5..7afea2f06f 100644 --- a/tests/api/api/test_docs.py +++ b/tests/api/api/test_docs.py @@ -22,7 +22,7 @@ def test_docs( def test_save_openapi_json( db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient ) -> None: - """"The purpose of the test is to create an openapi.json file that is used to run backward compatibility tests""" + """ "The purpose of the test is to create an openapi.json file that is used to run backward compatibility tests""" response = client.get("openapi.json") path = os.path.abspath(os.getcwd()) if os.getenv("MLRUN_BC_TESTS_OPENAPI_OUTPUT_PATH"): diff --git a/tests/api/api/test_frontend_spec.py b/tests/api/api/test_frontend_spec.py index a3d6896955..ee1bf88e4f 100644 --- a/tests/api/api/test_frontend_spec.py +++ b/tests/api/api/test_frontend_spec.py @@ -82,15 +82,20 @@ def test_get_frontend_spec_jobs_dashboard_url_resolution( # no grafana (None returned) so no url mlrun.mlconf.httpdb.authentication.mode = "iguazio" - mlrun.api.utils.clients.iguazio.Client().verify_request_session = unittest.mock.Mock( - return_value=( - mlrun.api.schemas.AuthInfo( - username=None, session="some-session", user_id=None, user_group_ids=[] + mlrun.api.utils.clients.iguazio.Client().verify_request_session = ( + unittest.mock.Mock( + return_value=( + mlrun.api.schemas.AuthInfo( + username=None, + session="some-session", + user_id=None, + user_group_ids=[], + ) ) ) ) - mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url = unittest.mock.Mock( - return_value=None + mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url = ( + unittest.mock.Mock(return_value=None) ) response = client.get("frontend-spec") assert response.status_code == http.HTTPStatus.OK.value @@ -100,8 +105,8 @@ def test_get_frontend_spec_jobs_dashboard_url_resolution( # happy secnario - grafana url found, verify returned correctly grafana_url = "some-url.com" - 
mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url = unittest.mock.Mock( - return_value=grafana_url + mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url = ( + unittest.mock.Mock(return_value=grafana_url) ) response = client.get("frontend-spec") diff --git a/tests/api/api/test_functions.py b/tests/api/api/test_functions.py index d5c46c1f30..dac8cac789 100644 --- a/tests/api/api/test_functions.py +++ b/tests/api/api/test_functions.py @@ -32,9 +32,11 @@ def test_build_status_pod_not_found(db: Session, client: TestClient): assert response.status_code == HTTPStatus.OK.value mlrun.api.utils.singletons.k8s.get_k8s().v1api = unittest.mock.Mock() - mlrun.api.utils.singletons.k8s.get_k8s().v1api.read_namespaced_pod = unittest.mock.Mock( - side_effect=kubernetes.client.rest.ApiException( - status=HTTPStatus.NOT_FOUND.value + mlrun.api.utils.singletons.k8s.get_k8s().v1api.read_namespaced_pod = ( + unittest.mock.Mock( + side_effect=kubernetes.client.rest.ApiException( + status=HTTPStatus.NOT_FOUND.value + ) ) ) response = client.get( @@ -67,7 +69,10 @@ def test_build_function_with_mlrun_bool(db: Session, client: TestClient): mlrun.api.api.endpoints.functions._build_function = unittest.mock.Mock( return_value=(function, True) ) - response = client.post("build/function", json=request_body,) + response = client.post( + "build/function", + json=request_body, + ) assert response.status_code == HTTPStatus.OK.value assert ( mlrun.api.api.endpoints.functions._build_function.call_args[0][3] diff --git a/tests/api/api/test_grafana_proxy.py b/tests/api/api/test_grafana_proxy.py index cc9bb7f432..7e067902b2 100644 --- a/tests/api/api/test_grafana_proxy.py +++ b/tests/api/api/test_grafana_proxy.py @@ -43,23 +43,28 @@ def test_grafana_proxy_model_endpoints_check_connection( db: Session, client: TestClient ): mlrun.mlconf.httpdb.authentication.mode = "iguazio" - mlrun.api.utils.clients.iguazio.Client().verify_request_session = unittest.mock.Mock( - return_value=( - mlrun.api.schemas.AuthInfo( - username=None, - session="some-session", - data_session="some-session", - user_id=None, - user_group_ids=[], + mlrun.api.utils.clients.iguazio.Client().verify_request_session = ( + unittest.mock.Mock( + return_value=( + mlrun.api.schemas.AuthInfo( + username=None, + session="some-session", + data_session="some-session", + user_id=None, + user_group_ids=[], + ) ) ) ) - response = client.get(url="grafana-proxy/model-endpoints",) + response = client.get( + url="grafana-proxy/model-endpoints", + ) assert response.status_code == 200 @pytest.mark.skipif( - _is_env_params_dont_exist(), reason=_build_skip_message(), + _is_env_params_dont_exist(), + reason=_build_skip_message(), ) def test_grafana_list_endpoints(db: Session, client: TestClient): endpoints_in = [_mock_random_endpoint("active") for _ in range(5)] @@ -102,7 +107,8 @@ def test_grafana_list_endpoints(db: Session, client: TestClient): @pytest.mark.skipif( - _is_env_params_dont_exist(), reason=_build_skip_message(), + _is_env_params_dont_exist(), + reason=_build_skip_message(), ) def test_grafana_individual_feature_analysis(db: Session, client: TestClient): endpoint_data = { @@ -157,7 +163,8 @@ def test_grafana_individual_feature_analysis(db: Session, client: TestClient): @pytest.mark.skipif( - _is_env_params_dont_exist(), reason=_build_skip_message(), + _is_env_params_dont_exist(), + reason=_build_skip_message(), ) def test_grafana_individual_feature_analysis_missing_field_doesnt_fail( db: Session, client: TestClient @@ -218,7 +225,8 @@ 
def test_grafana_individual_feature_analysis_missing_field_doesnt_fail( @pytest.mark.skipif( - _is_env_params_dont_exist(), reason=_build_skip_message(), + _is_env_params_dont_exist(), + reason=_build_skip_message(), ) def test_grafana_overall_feature_analysis(db: Session, client: TestClient): endpoint_data = { @@ -372,14 +380,17 @@ def cleanup_endpoints(db: Session, client: TestClient): try: # Cleanup TSDB frames.delete( - backend="tsdb", table=tsdb_path, if_missing=fpb2.IGNORE, + backend="tsdb", + table=tsdb_path, + if_missing=fpb2.IGNORE, ) except CreateError: pass @pytest.mark.skipif( - _is_env_params_dont_exist(), reason=_build_skip_message(), + _is_env_params_dont_exist(), + reason=_build_skip_message(), ) def test_grafana_incoming_features(db: Session, client: TestClient): path = config.model_endpoint_monitoring.store_prefixes.default.format( @@ -388,7 +399,9 @@ def test_grafana_incoming_features(db: Session, client: TestClient): _, container, path = parse_model_endpoint_store_prefix(path) frames = get_frames_client( - token=_get_access_key(), container=container, address=config.v3io_framesd, + token=_get_access_key(), + container=container, + address=config.v3io_framesd, ) frames.create(backend="tsdb", table=path, rate="10/m", if_exists=1) diff --git a/tests/api/api/test_model_endpoints.py b/tests/api/api/test_model_endpoints.py index 50e163af07..8e739fa8cd 100644 --- a/tests/api/api/test_model_endpoints.py +++ b/tests/api/api/test_model_endpoints.py @@ -22,27 +22,35 @@ def test_build_kv_cursor_filter_expression(): with pytest.raises(MLRunInvalidArgumentError): mlrun.api.crud.ModelEndpoints().build_kv_cursor_filter_expression("") - filter_expression = mlrun.api.crud.ModelEndpoints().build_kv_cursor_filter_expression( - project=TEST_PROJECT + filter_expression = ( + mlrun.api.crud.ModelEndpoints().build_kv_cursor_filter_expression( + project=TEST_PROJECT + ) ) assert filter_expression == f"project=='{TEST_PROJECT}'" - filter_expression = mlrun.api.crud.ModelEndpoints().build_kv_cursor_filter_expression( - project=TEST_PROJECT, function="test_function", model="test_model" + filter_expression = ( + mlrun.api.crud.ModelEndpoints().build_kv_cursor_filter_expression( + project=TEST_PROJECT, function="test_function", model="test_model" + ) ) expected = f"project=='{TEST_PROJECT}' AND function=='test_function' AND model=='test_model'" assert filter_expression == expected - filter_expression = mlrun.api.crud.ModelEndpoints().build_kv_cursor_filter_expression( - project=TEST_PROJECT, labels=["lbl1", "lbl2"] + filter_expression = ( + mlrun.api.crud.ModelEndpoints().build_kv_cursor_filter_expression( + project=TEST_PROJECT, labels=["lbl1", "lbl2"] + ) ) assert ( filter_expression == f"project=='{TEST_PROJECT}' AND exists(_lbl1) AND exists(_lbl2)" ) - filter_expression = mlrun.api.crud.ModelEndpoints().build_kv_cursor_filter_expression( - project=TEST_PROJECT, labels=["lbl1=1", "lbl2=2"] + filter_expression = ( + mlrun.api.crud.ModelEndpoints().build_kv_cursor_filter_expression( + project=TEST_PROJECT, labels=["lbl1=1", "lbl2=2"] + ) ) assert ( filter_expression == f"project=='{TEST_PROJECT}' AND _lbl1=='1' AND _lbl2=='2'" diff --git a/tests/api/api/test_pipelines.py b/tests/api/api/test_pipelines.py index 62ec23eb54..52e06e1b39 100644 --- a/tests/api/api/test_pipelines.py +++ b/tests/api/api/test_pipelines.py @@ -18,8 +18,8 @@ @pytest.fixture def kfp_client_mock(monkeypatch) -> kfp.Client: - mlrun.api.utils.singletons.k8s.get_k8s().is_running_inside_kubernetes_cluster = unittest.mock.Mock( - 
return_value=True + mlrun.api.utils.singletons.k8s.get_k8s().is_running_inside_kubernetes_cluster = ( + unittest.mock.Mock(return_value=True) ) kfp_client_mock = unittest.mock.Mock() monkeypatch.setattr(kfp, "Client", lambda *args, **kwargs: kfp_client_mock) @@ -67,7 +67,10 @@ def test_list_pipelines_formats( db, expected_runs, format_ ) _mock_list_runs(kfp_client_mock, runs) - response = client.get("projects/*/pipelines", params={"format": format_},) + response = client.get( + "projects/*/pipelines", + params={"format": format_}, + ) expected_response = mlrun.api.schemas.PipelinesOutput( runs=expected_runs, total_size=len(runs), next_page_token=None ) @@ -88,7 +91,8 @@ def test_get_pipeline_formats( api_run_detail = _generate_get_run_mock() _mock_get_run(kfp_client_mock, api_run_detail) response = client.get( - f"projects/*/pipelines/{api_run_detail.run.id}", params={"format": format_}, + f"projects/*/pipelines/{api_run_detail.run.id}", + params={"format": format_}, ) expected_run = mlrun.api.crud.Pipelines()._format_run( db, api_run_detail.to_dict()["run"], format_, api_run_detail.to_dict() @@ -112,7 +116,8 @@ def test_get_pipeline_no_project_opa_validation( api_run_detail = _generate_get_run_mock() _mock_get_run(kfp_client_mock, api_run_detail) response = client.get( - f"projects/*/pipelines/{api_run_detail.run.id}", params={"format": format_}, + f"projects/*/pipelines/{api_run_detail.run.id}", + params={"format": format_}, ) assert ( mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions.call_args[ @@ -221,7 +226,9 @@ def test_create_pipeline_legacy( contents = file.read() _mock_pipelines_creation(kfp_client_mock) response = client.post( - "submit_pipeline", data=contents, headers={"content-type": "application/yaml"}, + "submit_pipeline", + data=contents, + headers={"content-type": "application/yaml"}, ) response_body = response.json() assert response_body["id"] == "some-run-id" @@ -236,7 +243,8 @@ def _generate_get_run_mock() -> kfp_server_api.models.api_run_detail.ApiRunDetai name="run1", description="desc1", pipeline_spec=kfp_server_api.models.api_pipeline_spec.ApiPipelineSpec( - pipeline_id="pipe_id1", workflow_manifest=workflow_manifest, + pipeline_id="pipe_id1", + workflow_manifest=workflow_manifest, ), ), pipeline_runtime=kfp_server_api.models.api_pipeline_runtime.ApiPipelineRuntime( @@ -253,7 +261,8 @@ def _generate_list_runs_mocks(): name="run1", description="desc1", pipeline_spec=kfp_server_api.models.api_pipeline_spec.ApiPipelineSpec( - pipeline_id="pipe_id1", workflow_manifest=workflow_manifest, + pipeline_id="pipe_id1", + workflow_manifest=workflow_manifest, ), ), kfp_server_api.models.api_run.ApiRun( @@ -261,7 +270,8 @@ def _generate_list_runs_mocks(): name="run2", description="desc2", pipeline_spec=kfp_server_api.models.api_pipeline_spec.ApiPipelineSpec( - pipeline_id="pipe_id2", workflow_manifest=workflow_manifest, + pipeline_id="pipe_id2", + workflow_manifest=workflow_manifest, ), ), kfp_server_api.models.api_run.ApiRun( @@ -269,7 +279,8 @@ def _generate_list_runs_mocks(): name="run3", description="desc3", pipeline_spec=kfp_server_api.models.api_pipeline_spec.ApiPipelineSpec( - pipeline_id="pipe_id3", workflow_manifest=workflow_manifest, + pipeline_id="pipe_id3", + workflow_manifest=workflow_manifest, ), ), kfp_server_api.models.api_run.ApiRun( @@ -277,7 +288,8 @@ def _generate_list_runs_mocks(): name="run4", description="desc4", pipeline_spec=kfp_server_api.models.api_pipeline_spec.ApiPipelineSpec( - pipeline_id="pipe_id4", 
workflow_manifest=workflow_manifest, + pipeline_id="pipe_id4", + workflow_manifest=workflow_manifest, ), ), ] @@ -402,7 +414,9 @@ def _generate_workflow_manifest(with_status=False): def _mock_pipelines_creation(kfp_client_mock: kfp.Client): def _mock_create_experiment(name, description=None, namespace=None): return kfp_server_api.models.ApiExperiment( - id="some-exp-id", name=name, description=description, + id="some-exp-id", + name=name, + description=description, ) def _mock_run_pipeline( @@ -472,7 +486,11 @@ def _assert_list_pipelines_response( ): assert response.status_code == http.HTTPStatus.OK.value assert ( - deepdiff.DeepDiff(expected_response.dict(), response.json(), ignore_order=True,) + deepdiff.DeepDiff( + expected_response.dict(), + response.json(), + ignore_order=True, + ) == {} ) @@ -480,5 +498,10 @@ def _assert_list_pipelines_response( def _assert_get_pipeline_response(expected_response: dict, response): assert response.status_code == http.HTTPStatus.OK.value assert ( - deepdiff.DeepDiff(expected_response, response.json(), ignore_order=True,) == {} + deepdiff.DeepDiff( + expected_response, + response.json(), + ignore_order=True, + ) + == {} ) diff --git a/tests/api/api/test_projects.py b/tests/api/api/test_projects.py index 08a3ab8c00..1c7a3b28ec 100644 --- a/tests/api/api/test_projects.py +++ b/tests/api/api/test_projects.py @@ -82,8 +82,8 @@ def test_get_non_existing_project( not found - which "ruined" the `mlrun.get_or_create_project` logic - so adding a specific test to verify it works """ project = "does-not-exist" - mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions = unittest.mock.Mock( - side_effect=mlrun.errors.MLRunUnauthorizedError("bla") + mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions = ( + unittest.mock.Mock(side_effect=mlrun.errors.MLRunUnauthorizedError("bla")) ) response = client.get(f"projects/{project}") assert response.status_code == HTTPStatus.NOT_FOUND.value @@ -272,11 +272,15 @@ def test_list_and_get_project_summaries( # create schedules for the project schedules_count = 3 _create_schedules( - client, project_name, schedules_count, + client, + project_name, + schedules_count, ) # mock pipelines for the project - running_pipelines_count = _mock_pipelines(project_name,) + running_pipelines_count = _mock_pipelines( + project_name, + ) # list project summaries response = client.get("project-summaries") @@ -393,7 +397,9 @@ def test_delete_project_not_deleting_versioned_objects_multiple_times( } assert len(distinct_artifact_keys) < len(response.json()["artifacts"]) - response = client.get(f"projects/{project_name}/feature-sets",) + response = client.get( + f"projects/{project_name}/feature-sets", + ) assert response.status_code == HTTPStatus.OK.value distinct_feature_set_names = { feature_set["metadata"]["name"] @@ -402,7 +408,9 @@ def test_delete_project_not_deleting_versioned_objects_multiple_times( # ensure there are indeed several versions of the same feature_set name assert len(distinct_feature_set_names) < len(response.json()["feature_sets"]) - response = client.get(f"projects/{project_name}/feature-vectors",) + response = client.get( + f"projects/{project_name}/feature-vectors", + ) assert response.status_code == HTTPStatus.OK.value distinct_feature_vector_names = { feature_vector["metadata"]["name"] @@ -433,8 +441,9 @@ def test_delete_project_not_deleting_versioned_objects_multiple_times( assert mlrun.api.utils.singletons.db.get_db().delete_feature_set.call_count == len( distinct_feature_set_names ) - 
assert mlrun.api.utils.singletons.db.get_db().delete_feature_vector.call_count == len( - distinct_feature_vector_names + assert ( + mlrun.api.utils.singletons.db.get_db().delete_feature_vector.call_count + == len(distinct_feature_vector_names) ) @@ -507,7 +516,11 @@ def test_list_projects_leader_format( project["data"]["metadata"]["name"] for project in response.json()["projects"] ] assert ( - deepdiff.DeepDiff(project_names, returned_project_names, ignore_order=True,) + deepdiff.DeepDiff( + project_names, + returned_project_names, + ignore_order=True, + ) == {} ) @@ -620,7 +633,9 @@ def test_projects_crud( ) assert ( deepdiff.DeepDiff( - response.json()["metadata"]["labels"], labels_1, ignore_order=True, + response.json()["metadata"]["labels"], + labels_1, + ignore_order=True, ) == {} ) @@ -861,7 +876,8 @@ def _assert_resources_in_project( def _assert_schedules_in_project( - project: str, assert_no_resources: bool = False, + project: str, + assert_no_resources: bool = False, ) -> int: number_of_schedules = len( mlrun.api.utils.singletons.scheduler.get_scheduler()._list_schedules_from_scheduler( @@ -875,7 +891,10 @@ def _assert_schedules_in_project( return number_of_schedules -def _assert_logs_in_project(project: str, assert_no_resources: bool = False,) -> int: +def _assert_logs_in_project( + project: str, + assert_no_resources: bool = False, +) -> int: logs_path = mlrun.api.api.utils.project_logs_path(project) number_of_log_files = 0 if logs_path.exists(): @@ -1022,10 +1041,15 @@ def _list_project_names_and_assert( params = params or {} params["format"] = mlrun.api.schemas.ProjectsFormat.name_only # list - names only - filter by state - response = client.get("projects", params=params,) + response = client.get( + "projects", + params=params, + ) assert ( deepdiff.DeepDiff( - expected_names, response.json()["projects"], ignore_order=True, + expected_names, + response.json()["projects"], + ignore_order=True, ) == {} ) diff --git a/tests/api/api/test_runs.py b/tests/api/api/test_runs.py index bee1477670..83cf0396b2 100644 --- a/tests/api/api/test_runs.py +++ b/tests/api/api/test_runs.py @@ -149,7 +149,9 @@ def test_list_runs_times_filters(db: Session, client: TestClient) -> None: start_time_to=run_2_update_time.isoformat(), ) assert_time_range_request( - client, [run_1_uid, run_2_uid], start_time_from=run_1_start_time.isoformat(), + client, + [run_1_uid, run_2_uid], + start_time_from=run_1_start_time.isoformat(), ) # all last update time range @@ -160,10 +162,14 @@ def test_list_runs_times_filters(db: Session, client: TestClient) -> None: last_update_time_to=run_2_update_time, ) assert_time_range_request( - client, [run_1_uid, run_2_uid], last_update_time_from=run_1_update_time, + client, + [run_1_uid, run_2_uid], + last_update_time_from=run_1_update_time, ) assert_time_range_request( - client, [run_1_uid, run_2_uid], last_update_time_to=run_2_update_time, + client, + [run_1_uid, run_2_uid], + last_update_time_to=run_2_update_time, ) # catch only first @@ -174,7 +180,9 @@ def test_list_runs_times_filters(db: Session, client: TestClient) -> None: start_time_to=between_run_1_and_2, ) assert_time_range_request( - client, [run_1_uid], start_time_to=between_run_1_and_2, + client, + [run_1_uid], + start_time_to=between_run_1_and_2, ) assert_time_range_request( client, @@ -191,7 +199,9 @@ def test_list_runs_times_filters(db: Session, client: TestClient) -> None: start_time_to=run_2_update_time, ) assert_time_range_request( - client, [run_2_uid], last_update_time_from=run_2_start_time, + client, + 
[run_2_uid], + last_update_time_from=run_2_start_time, ) @@ -215,10 +225,18 @@ def test_list_runs_partition_by(db: Session, client: TestClient) -> None: mlrun.api.crud.Runs().store_run(db, run, uid, iteration, project) # basic list, all projects, all iterations so 3 projects * 3 names * 3 uids * 3 iterations = 81 - runs = _list_and_assert_objects(client, {"project": "*"}, 81,) + runs = _list_and_assert_objects( + client, + {"project": "*"}, + 81, + ) # basic list, specific project, only iteration 0, so 3 names * 3 uids = 9 - runs = _list_and_assert_objects(client, {"project": projects[0], "iter": False}, 9,) + runs = _list_and_assert_objects( + client, + {"project": projects[0], "iter": False}, + 9, + ) # partioned list, specific project, 1 row per partition by default, so 3 names * 1 row = 3 runs = _list_and_assert_objects( diff --git a/tests/api/api/test_runtime_resources.py b/tests/api/api/test_runtime_resources.py index 54a65f9f54..dcba8e999d 100644 --- a/tests/api/api/test_runtime_resources.py +++ b/tests/api/api/test_runtime_resources.py @@ -34,11 +34,15 @@ def test_list_runtimes_resources_opa_filtering( ) _mock_opa_filter_and_assert_list_response( - client, grouped_by_project_runtime_resources_output, [project_3], + client, + grouped_by_project_runtime_resources_output, + [project_3], ) _mock_opa_filter_and_assert_list_response( - client, grouped_by_project_runtime_resources_output, [project_2], + client, + grouped_by_project_runtime_resources_output, + [project_2], ) @@ -86,7 +90,14 @@ def test_list_runtimes_resources_group_by_job( ][mlrun.runtimes.RuntimeKinds.mpijob].dict() }, } - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + body, + expected_body, + ignore_order=True, + ) + == {} + ) def test_list_runtimes_resources_no_group_by( @@ -110,7 +121,9 @@ def test_list_runtimes_resources_no_group_by( mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions = unittest.mock.Mock( side_effect=lambda _, resources, *args, **kwargs: resources ) - response = client.get("projects/*/runtime-resources",) + response = client.get( + "projects/*/runtime-resources", + ) body = response.json() expected_body = [ mlrun.api.schemas.KindRuntimeResources( @@ -147,7 +160,14 @@ def test_list_runtimes_resources_no_group_by( ), ).dict(), ] - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + body, + expected_body, + ignore_order=True, + ) + == {} + ) def test_list_runtime_resources_no_resources( @@ -160,7 +180,9 @@ def test_list_runtime_resources_no_resources( mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions = unittest.mock.Mock( return_value=[] ) - response = client.get("projects/*/runtime-resources",) + response = client.get( + "projects/*/runtime-resources", + ) body = response.json() assert body == [] response = client.get( @@ -191,7 +213,14 @@ def test_list_runtime_resources_no_resources( kind=mlrun.runtimes.RuntimeKinds.job, resources=mlrun.api.schemas.RuntimeResources(), ).dict() - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + body, + expected_body, + ignore_order=True, + ) + == {} + ) def test_list_runtime_resources_filter_by_kind( @@ -237,13 +266,27 @@ def test_list_runtime_resources_filter_by_kind( ), ).dict() expected_body = [expected_runtime_resources] - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + 
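# --- Illustrative sketch (not part of the patch): the runtime-resources tests
# above replace the project-permission filter with a Mock whose side_effect
# simply hands the resources back, i.e. "authorize everything". A standalone
# version of that pattern, with invented names:
import unittest.mock

class FakeVerifier:
    def filter_project_resources_by_permissions(self, resource_type, resources, *args, **kwargs):
        raise NotImplementedError("replaced by the mock below")

verifier = FakeVerifier()
verifier.filter_project_resources_by_permissions = unittest.mock.Mock(
    side_effect=lambda _, resources, *args, **kwargs: resources
)

# the Mock forwards its arguments to the lambda, so the input comes back untouched
assert verifier.filter_project_resources_by_permissions("runtime", ["pod-a", "pod-b"]) == [
    "pod-a",
    "pod-b",
]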
body, + expected_body, + ignore_order=True, + ) + == {} + ) # test legacy endpoint response = client.get(f"runtimes/{mlrun.runtimes.RuntimeKinds.job}") body = response.json() expected_body = expected_runtime_resources - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + body, + expected_body, + ignore_order=True, + ) + == {} + ) def test_delete_runtime_resources_nothing_allowed( @@ -309,15 +352,28 @@ def test_delete_runtime_resources_opa_filtering( _mock_runtime_handlers_delete_resources( mlrun.runtimes.RuntimeKinds.runtime_with_handlers(), allowed_projects ) - response = client.delete("projects/*/runtime-resources",) + response = client.delete( + "projects/*/runtime-resources", + ) body = response.json() - expected_body = _filter_allowed_projects_from_grouped_by_project_runtime_resources_output( - allowed_projects, grouped_by_project_runtime_resources_output + expected_body = ( + _filter_allowed_projects_from_grouped_by_project_runtime_resources_output( + allowed_projects, grouped_by_project_runtime_resources_output + ) + ) + assert ( + deepdiff.DeepDiff( + body, + expected_body, + ignore_order=True, + ) + == {} ) - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} # legacy endpoint - response = client.delete("runtimes",) + response = client.delete( + "runtimes", + ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value @@ -344,15 +400,28 @@ def test_delete_runtime_resources_with_legacy_builder_pod_opa_filtering( _mock_runtime_handlers_delete_resources( mlrun.runtimes.RuntimeKinds.runtime_with_handlers(), allowed_projects ) - response = client.delete("projects/*/runtime-resources",) + response = client.delete( + "projects/*/runtime-resources", + ) body = response.json() - expected_body = _filter_allowed_projects_from_grouped_by_project_runtime_resources_output( - [""], grouped_by_project_runtime_resources_output + expected_body = ( + _filter_allowed_projects_from_grouped_by_project_runtime_resources_output( + [""], grouped_by_project_runtime_resources_output + ) + ) + assert ( + deepdiff.DeepDiff( + body, + expected_body, + ignore_order=True, + ) + == {} ) - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} # legacy endpoint - response = client.delete("runtimes",) + response = client.delete( + "runtimes", + ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value @@ -371,8 +440,10 @@ def test_delete_runtime_resources_with_kind( ) = _generate_grouped_by_project_runtime_resources_output() kind = mlrun.runtimes.RuntimeKinds.job - grouped_by_project_runtime_resources_output = _filter_kind_from_grouped_by_project_runtime_resources_output( - kind, grouped_by_project_runtime_resources_output + grouped_by_project_runtime_resources_output = ( + _filter_kind_from_grouped_by_project_runtime_resources_output( + kind, grouped_by_project_runtime_resources_output + ) ) mlrun.api.crud.RuntimeResources().list_runtime_resources = unittest.mock.Mock( return_value=grouped_by_project_runtime_resources_output @@ -383,15 +454,27 @@ def test_delete_runtime_resources_with_kind( return_value=allowed_projects ) _mock_runtime_handlers_delete_resources([kind], allowed_projects) - response = client.delete("projects/*/runtime-resources", params={"kind": kind},) + response = client.delete( + "projects/*/runtime-resources", + params={"kind": kind}, + ) body = response.json() expected_body = _filter_allowed_projects_and_kind_from_grouped_by_project_runtime_resources_output( allowed_projects, kind, 
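# --- Illustrative sketch (not part of the patch): the hunks above (and most of
# this diff) change layout only. A trailing comma after the last argument is
# treated by recent black releases as "magic": it pins the exploded,
# one-argument-per-line form, while omitting it lets black collapse the call
# when it fits on one line; over-long right-hand sides are likewise wrapped in
# parentheses, as in the expected_body hunks above. Toy stand-in helper:
def list_and_assert(client, params, expected_count):
    return client, params, expected_count  # stub body, layout is the point here

# collapsed: no trailing comma, short enough for one line
list_and_assert(None, {"project": "*"}, 81)

# exploded: the trailing comma after 81 keeps one argument per line
list_and_assert(
    None,
    {"project": "*"},
    81,
)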
grouped_by_project_runtime_resources_output ) - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + body, + expected_body, + ignore_order=True, + ) + == {} + ) # legacy endpoint - response = client.delete(f"runtimes/{kind}",) + response = client.delete( + f"runtimes/{kind}", + ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value @@ -428,21 +511,32 @@ def test_delete_runtime_resources_with_object_id( ) _mock_runtime_handlers_delete_resources([kind], [project_1]) response = client.delete( - "projects/*/runtime-resources", params={"kind": kind, "object-id": object_id}, + "projects/*/runtime-resources", + params={"kind": kind, "object-id": object_id}, ) body = response.json() expected_body = _filter_allowed_projects_and_kind_from_grouped_by_project_runtime_resources_output( [project_1], kind, grouped_by_project_runtime_resources_output, structured=False ) - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + body, + expected_body, + ignore_order=True, + ) + == {} + ) # legacy endpoint - response = client.delete(f"runtimes/{kind}/{object_id}",) + response = client.delete( + f"runtimes/{kind}/{object_id}", + ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value def _mock_runtime_handlers_delete_resources( - kinds: typing.List[str], allowed_projects: typing.List[str], + kinds: typing.List[str], + allowed_projects: typing.List[str], ): def _assert_delete_resources_label_selector( db, @@ -467,18 +561,26 @@ def _assert_delete_resources_label_selector( def _assert_empty_responses_in_delete_endpoints(client: fastapi.testclient.TestClient): - response = client.delete("projects/*/runtime-resources",) + response = client.delete( + "projects/*/runtime-resources", + ) body = response.json() assert body == {} # legacy endpoints - response = client.delete("runtimes",) + response = client.delete( + "runtimes", + ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value - response = client.delete(f"runtimes/{mlrun.runtimes.RuntimeKinds.job}",) + response = client.delete( + f"runtimes/{mlrun.runtimes.RuntimeKinds.job}", + ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value - response = client.delete(f"runtimes/{mlrun.runtimes.RuntimeKinds.job}/some-id",) + response = client.delete( + f"runtimes/{mlrun.runtimes.RuntimeKinds.job}/some-id", + ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value @@ -635,10 +737,19 @@ def _mock_opa_filter_and_assert_list_response( params={"group-by": mlrun.api.schemas.ListRuntimeResourcesGroupByField.project}, ) body = response.json() - expected_body = _filter_allowed_projects_from_grouped_by_project_runtime_resources_output( - opa_filter_response, grouped_by_project_runtime_resources_output + expected_body = ( + _filter_allowed_projects_from_grouped_by_project_runtime_resources_output( + opa_filter_response, grouped_by_project_runtime_resources_output + ) + ) + assert ( + deepdiff.DeepDiff( + body, + expected_body, + ignore_order=True, + ) + == {} ) - assert deepdiff.DeepDiff(body, expected_body, ignore_order=True,) == {} def _filter_allowed_projects_and_kind_from_grouped_by_project_runtime_resources_output( @@ -647,8 +758,10 @@ def _filter_allowed_projects_and_kind_from_grouped_by_project_runtime_resources_ grouped_by_project_runtime_resources_output: mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, structured: bool = False, ): - filtered_output = 
_filter_allowed_projects_from_grouped_by_project_runtime_resources_output( - allowed_projects, grouped_by_project_runtime_resources_output, structured + filtered_output = ( + _filter_allowed_projects_from_grouped_by_project_runtime_resources_output( + allowed_projects, grouped_by_project_runtime_resources_output, structured + ) ) return _filter_kind_from_grouped_by_project_runtime_resources_output( filter_kind, filtered_output diff --git a/tests/api/api/test_submit.py b/tests/api/api/test_submit.py index 062150cd61..6bcc82487f 100644 --- a/tests/api/api/test_submit.py +++ b/tests/api/api/test_submit.py @@ -63,8 +63,8 @@ def pod_create_mock(): authenticate_request_orig_function = ( mlrun.api.utils.auth.verifier.AuthVerifier().authenticate_request ) - mlrun.api.utils.auth.verifier.AuthVerifier().authenticate_request = unittest.mock.Mock( - return_value=auth_info_mock + mlrun.api.utils.auth.verifier.AuthVerifier().authenticate_request = ( + unittest.mock.Mock(return_value=auth_info_mock) ) yield get_k8s().create_pod diff --git a/tests/api/conftest.py b/tests/api/conftest.py index e15400da23..ee9190c82f 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -126,7 +126,9 @@ def get_expected_env_variables_from_secrets( def assert_project_secrets(self, project: str, secrets: dict): assert ( deepdiff.DeepDiff( - self.project_secrets_map[project], secrets, ignore_order=True, + self.project_secrets_map[project], + secrets, + ignore_order=True, ) == {} ) @@ -168,20 +170,20 @@ def k8s_secrets_mock(client: TestClient) -> K8sSecretsMock: for name in mocked_function_names } - mlrun.api.utils.singletons.k8s.get_k8s().is_running_inside_kubernetes_cluster = unittest.mock.Mock( - side_effect=k8s_secrets_mock.is_running_in_k8s_cluster + mlrun.api.utils.singletons.k8s.get_k8s().is_running_inside_kubernetes_cluster = ( + unittest.mock.Mock(side_effect=k8s_secrets_mock.is_running_in_k8s_cluster) ) - mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_keys = unittest.mock.Mock( - side_effect=k8s_secrets_mock.get_project_secret_keys + mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_keys = ( + unittest.mock.Mock(side_effect=k8s_secrets_mock.get_project_secret_keys) ) - mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_data = unittest.mock.Mock( - side_effect=k8s_secrets_mock.get_project_secret_data + mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_data = ( + unittest.mock.Mock(side_effect=k8s_secrets_mock.get_project_secret_data) ) mlrun.api.utils.singletons.k8s.get_k8s().store_project_secrets = unittest.mock.Mock( side_effect=k8s_secrets_mock.store_project_secrets ) - mlrun.api.utils.singletons.k8s.get_k8s().delete_project_secrets = unittest.mock.Mock( - side_effect=k8s_secrets_mock.delete_project_secrets + mlrun.api.utils.singletons.k8s.get_k8s().delete_project_secrets = ( + unittest.mock.Mock(side_effect=k8s_secrets_mock.delete_project_secrets) ) yield k8s_secrets_mock diff --git a/tests/api/crud/test_secrets.py b/tests/api/crud/test_secrets.py index a59052aaf1..cb01f5534d 100644 --- a/tests/api/crud/test_secrets.py +++ b/tests/api/crud/test_secrets.py @@ -537,16 +537,23 @@ def test_secrets_crud_internal_secrets( # delete regular secret - pass mlrun.api.crud.Secrets().delete_secrets( - project, provider, [regular_secret_key], + project, + provider, + [regular_secret_key], ) # delete with empty list (delete all) - shouldn't delete internal mlrun.api.crud.Secrets().delete_secrets( - project, provider, [], + project, + provider, + [], ) # list to verify - only 
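# --- Illustrative sketch (not part of the patch): the conftest hunk above wires
# an in-memory helper into the k8s singleton by replacing each method with
# Mock(side_effect=helper.method) and yielding the helper to the test. A
# stripped-down fixture of the same shape, all names invented:
import unittest.mock
import pytest

class FakeK8sClient:
    def store_project_secrets(self, project, secrets):
        raise RuntimeError("would talk to a real cluster")

class K8sSecretsMock:
    def __init__(self):
        self.project_secrets_map = {}

    def store_project_secrets(self, project, secrets):
        self.project_secrets_map.setdefault(project, {}).update(secrets)

_k8s_singleton = FakeK8sClient()

@pytest.fixture()
def k8s_secrets_mock():
    mock = K8sSecretsMock()
    # route the singleton's method into the in-memory helper for the test run
    _k8s_singleton.store_project_secrets = unittest.mock.Mock(
        side_effect=mock.store_project_secrets
    )
    yield mock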
internal should remain secrets_data = mlrun.api.crud.Secrets().list_secrets( - project, provider, allow_secrets_from_k8s=True, allow_internal_secrets=True, + project, + provider, + allow_secrets_from_k8s=True, + allow_internal_secrets=True, ) assert ( deepdiff.DeepDiff( @@ -560,7 +567,9 @@ def test_secrets_crud_internal_secrets( # delete internal secret without allow - fail with pytest.raises(mlrun.errors.MLRunAccessDeniedError): mlrun.api.crud.Secrets().delete_secrets( - project, provider, [internal_secret_key], + project, + provider, + [internal_secret_key], ) # delete internal secret with allow - pass @@ -571,7 +580,14 @@ def test_secrets_crud_internal_secrets( secrets_data = mlrun.api.crud.Secrets().list_secrets( project, provider, allow_secrets_from_k8s=True ) - assert deepdiff.DeepDiff(secrets_data.secrets, {}, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + secrets_data.secrets, + {}, + ignore_order=True, + ) + == {} + ) # store internal secret again to verify deletion with empty list with allow - pass mlrun.api.crud.Secrets().store_secrets( @@ -583,10 +599,20 @@ def test_secrets_crud_internal_secrets( ) # delete with empty list (delete all) with allow - nothing should remain mlrun.api.crud.Secrets().delete_secrets( - project, provider, [], allow_internal_secrets=True, + project, + provider, + [], + allow_internal_secrets=True, ) # list to verify secrets_data = mlrun.api.crud.Secrets().list_secrets( project, provider, allow_secrets_from_k8s=True ) - assert deepdiff.DeepDiff(secrets_data.secrets, {}, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + secrets_data.secrets, + {}, + ignore_order=True, + ) + == {} + ) diff --git a/tests/api/db/test_artifacts.py b/tests/api/db/test_artifacts.py index cd948ca428..ef572d1424 100644 --- a/tests/api/db/test_artifacts.py +++ b/tests/api/db/test_artifacts.py @@ -30,10 +30,16 @@ def test_list_artifact_name_filter(db: DBInterface, db_session: Session): uid = "artifact_uid" db.store_artifact( - db_session, artifact_name_1, artifact_1, uid, + db_session, + artifact_name_1, + artifact_1, + uid, ) db.store_artifact( - db_session, artifact_name_2, artifact_2, uid, + db_session, + artifact_name_2, + artifact_2, + uid, ) artifacts = db.list_artifacts(db_session) assert len(artifacts) == 2 @@ -102,10 +108,16 @@ def test_list_artifact_kind_filter(db: DBInterface, db_session: Session): uid = "artifact_uid" db.store_artifact( - db_session, artifact_name_1, artifact_1, uid, + db_session, + artifact_name_1, + artifact_1, + uid, ) db.store_artifact( - db_session, artifact_name_2, artifact_2, uid, + db_session, + artifact_name_2, + artifact_2, + uid, ) artifacts = db.list_artifacts(db_session) assert len(artifacts) == 2 @@ -139,16 +151,28 @@ def test_list_artifact_category_filter(db: DBInterface, db_session: Session): uid = "artifact_uid" db.store_artifact( - db_session, artifact_name_1, artifact_1, uid, + db_session, + artifact_name_1, + artifact_1, + uid, ) db.store_artifact( - db_session, artifact_name_2, artifact_2, uid, + db_session, + artifact_name_2, + artifact_2, + uid, ) db.store_artifact( - db_session, artifact_name_3, artifact_3, uid, + db_session, + artifact_name_3, + artifact_3, + uid, ) db.store_artifact( - db_session, artifact_name_4, artifact_4, uid, + db_session, + artifact_name_4, + artifact_4, + uid, ) artifacts = db.list_artifacts(db_session) assert len(artifacts) == 4 @@ -182,10 +206,16 @@ def test_store_artifact_tagging(db: DBInterface, db_session: Session): artifact_1_with_kind_uid = "artifact_uid_2" 
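# --- Illustrative sketch (not part of the patch): the secrets tests above wrap
# forbidden operations in pytest.raises, so the test passes only when the
# expected exception type is raised. Standalone form, with invented names:
import pytest

class AccessDeniedError(Exception):
    pass

def delete_secret(key, allow_internal=False):
    if key.startswith("mlrun.") and not allow_internal:
        raise AccessDeniedError("internal secrets may only be deleted explicitly")

def test_delete_internal_secret_denied():
    # `match` is a regex searched against the exception message
    with pytest.raises(AccessDeniedError, match="internal secrets"):
        delete_secret("mlrun.some-internal-key")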
db.store_artifact( - db_session, artifact_1_key, artifact_1_body, artifact_1_uid, + db_session, + artifact_1_key, + artifact_1_body, + artifact_1_uid, ) db.store_artifact( - db_session, artifact_1_key, artifact_1_with_kind_body, artifact_1_with_kind_uid, + db_session, + artifact_1_key, + artifact_1_with_kind_body, + artifact_1_with_kind_uid, ) artifact = db.read_artifact(db_session, artifact_1_key, tag="latest") assert artifact["kind"] == artifact_1_kind @@ -211,19 +241,41 @@ def test_store_artifact_restoring_multiple_tags(db: DBInterface, db_session: Ses artifact_2_tag = "artifact_tag_2" db.store_artifact( - db_session, artifact_key, artifact_1_body, artifact_1_uid, tag=artifact_1_tag, + db_session, + artifact_key, + artifact_1_body, + artifact_1_uid, + tag=artifact_1_tag, ) db.store_artifact( - db_session, artifact_key, artifact_2_body, artifact_2_uid, tag=artifact_2_tag, + db_session, + artifact_key, + artifact_2_body, + artifact_2_uid, + tag=artifact_2_tag, ) artifacts = db.list_artifacts(db_session, artifact_key, tag="*") assert len(artifacts) == 2 expected_uids = [artifact_1_uid, artifact_2_uid] uids = [artifact["metadata"]["uid"] for artifact in artifacts] - assert deepdiff.DeepDiff(expected_uids, uids, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + expected_uids, + uids, + ignore_order=True, + ) + == {} + ) expected_tags = [artifact_1_tag, artifact_2_tag] tags = [artifact["tag"] for artifact in artifacts] - assert deepdiff.DeepDiff(expected_tags, tags, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + expected_tags, + tags, + ignore_order=True, + ) + == {} + ) artifact = db.read_artifact(db_session, artifact_key, tag=artifact_1_tag) assert artifact["metadata"]["uid"] == artifact_1_uid assert artifact["tag"] == artifact_1_tag @@ -251,10 +303,18 @@ def test_read_artifact_tag_resolution(db: DBInterface, db_session: Session): artifact_2_tag = "artifact_tag_2" db.store_artifact( - db_session, artifact_1_key, artifact_1_body, artifact_uid, tag=artifact_1_tag, + db_session, + artifact_1_key, + artifact_1_body, + artifact_uid, + tag=artifact_1_tag, ) db.store_artifact( - db_session, artifact_2_key, artifact_2_body, artifact_uid, tag=artifact_2_tag, + db_session, + artifact_2_key, + artifact_2_body, + artifact_uid, + tag=artifact_2_tag, ) with pytest.raises(mlrun.errors.MLRunNotFoundError): db.read_artifact(db_session, artifact_1_key, artifact_2_tag) @@ -285,10 +345,18 @@ def test_delete_artifacts_tag_filter(db: DBInterface, db_session: Session): artifact_2_tag = "artifact_tag_two" db.store_artifact( - db_session, artifact_1_key, artifact_1_body, artifact_1_uid, tag=artifact_1_tag, + db_session, + artifact_1_key, + artifact_1_body, + artifact_1_uid, + tag=artifact_1_tag, ) db.store_artifact( - db_session, artifact_2_key, artifact_2_body, artifact_2_uid, tag=artifact_2_tag, + db_session, + artifact_2_key, + artifact_2_body, + artifact_2_uid, + tag=artifact_2_tag, ) db.del_artifacts(db_session, tag=artifact_1_tag) artifacts = db.list_artifacts(db_session, tag=artifact_1_tag) @@ -314,18 +382,32 @@ def test_list_artifacts_exact_name_match(db: DBInterface, db_session: Session): # Store each twice - once with no iter, and once with an iter db.store_artifact( - db_session, artifact_1_key, artifact_1_body, artifact_1_uid, + db_session, + artifact_1_key, + artifact_1_body, + artifact_1_uid, ) artifact_1_body["iter"] = 42 db.store_artifact( - db_session, artifact_1_key, artifact_1_body, artifact_1_uid, iter=42, + db_session, + artifact_1_key, + artifact_1_body, + 
artifact_1_uid, + iter=42, ) db.store_artifact( - db_session, artifact_2_key, artifact_2_body, artifact_2_uid, + db_session, + artifact_2_key, + artifact_2_body, + artifact_2_uid, ) artifact_2_body["iter"] = 42 db.store_artifact( - db_session, artifact_2_key, artifact_2_body, artifact_2_uid, iter=42, + db_session, + artifact_2_key, + artifact_2_body, + artifact_2_uid, + iter=42, ) def _list_and_assert_count(key, count, iter=None): @@ -364,7 +446,11 @@ def _generate_artifact_with_iterations( artifact_body["link_iteration"] = best_iter artifact_body["iter"] = iter db.store_artifact( - db_session, key, artifact_body, uid, iter=iter, + db_session, + key, + artifact_body, + uid, + iter=iter, ) @@ -454,7 +540,8 @@ def test_list_artifacts_best_iter(db: DBInterface, db_session: Session): indirect=["data_migration_db", "db_session"], ) def test_data_migration_fix_artifact_tags_duplications( - data_migration_db: DBInterface, db_session: Session, + data_migration_db: DBInterface, + db_session: Session, ): def _buggy_tag_artifacts(session, objs, project: str, name: str): # This is the function code that was used before we did the fix and added the data migration @@ -499,10 +586,16 @@ def _upsert(session, obj, ignore=False): ) data_migration_db.store_artifact( - db_session, artifact_1_key, artifact_1_body, artifact_1_uid, + db_session, + artifact_1_key, + artifact_1_body, + artifact_1_uid, ) data_migration_db.store_artifact( - db_session, artifact_1_key, artifact_1_with_kind_body, artifact_1_with_kind_uid, + db_session, + artifact_1_key, + artifact_1_with_kind_body, + artifact_1_with_kind_uid, ) data_migration_db.store_artifact( db_session, artifact_2_key, artifact_2_body, artifact_2_uid, tag="not-latest" @@ -609,7 +702,8 @@ def _upsert(session, obj, ignore=False): indirect=["data_migration_db", "db_session"], ) def test_data_migration_fix_datasets_large_previews( - data_migration_db: DBInterface, db_session: Session, + data_migration_db: DBInterface, + db_session: Session, ): artifact_with_valid_preview_key = "artifact-with-valid-preview-key" artifact_with_valid_preview_uid = "artifact-with-valid-preview-uid" diff --git a/tests/api/db/test_projects.py b/tests/api/db/test_projects.py index 20a51fd7b1..dc784c7d92 100644 --- a/tests/api/db/test_projects.py +++ b/tests/api/db/test_projects.py @@ -20,7 +20,8 @@ "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_get_project( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): project_name = "project-name" project_description = "some description" @@ -42,7 +43,9 @@ def test_get_project( assert project_output.spec.description == project_description assert ( deepdiff.DeepDiff( - project_labels, project_output.metadata.labels, ignore_order=True, + project_labels, + project_output.metadata.labels, + ignore_order=True, ) == {} ) @@ -53,7 +56,8 @@ def test_get_project( "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_get_project_with_pre_060_record( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): project_name = "project_name" _generate_and_insert_pre_060_record(db_session, project_name) @@ -61,7 +65,10 @@ def test_get_project_with_pre_060_record( db_session.query(Project).filter(Project.name == project_name).one() ) assert pre_060_record.full_object is None - project = db.get_project(db_session, project_name,) + project = db.get_project( + db_session, + project_name, + ) assert 
project.metadata.name == project_name updated_record = ( db_session.query(Project).filter(Project.name == project_name).one() @@ -75,7 +82,8 @@ def test_get_project_with_pre_060_record( "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_data_migration_enrich_project_state( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): for i in range(10): project_name = f"project-name-{i}" @@ -109,7 +117,8 @@ def _generate_and_insert_pre_060_record( "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_list_project( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): expected_projects = [ {"name": "project-name-1"}, @@ -152,7 +161,8 @@ def test_list_project( "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_list_project_names_filter( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): project_names = ["project-1", "project-2", "project-3", "project-4", "project-5"] @@ -171,12 +181,18 @@ def test_list_project_names_filter( ) assert ( - deepdiff.DeepDiff(filter_names, projects_output.projects, ignore_order=True,) + deepdiff.DeepDiff( + filter_names, + projects_output.projects, + ignore_order=True, + ) == {} ) projects_output = db.list_projects( - db_session, format_=mlrun.api.schemas.ProjectsFormat.name_only, names=[], + db_session, + format_=mlrun.api.schemas.ProjectsFormat.name_only, + names=[], ) assert projects_output.projects == [] @@ -187,7 +203,8 @@ def test_list_project_names_filter( "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_create_project( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): project_name = "project-name" project_description = "some description" @@ -200,7 +217,9 @@ def test_create_project( db_session, mlrun.api.schemas.Project( metadata=mlrun.api.schemas.ProjectMetadata( - name=project_name, created=project_created, labels=project_labels, + name=project_name, + created=project_created, + labels=project_labels, ), spec=mlrun.api.schemas.ProjectSpec(description=project_description), ), @@ -213,7 +232,9 @@ def test_create_project( assert project_output.metadata.created != project_created assert ( deepdiff.DeepDiff( - project_labels, project_output.metadata.labels, ignore_order=True, + project_labels, + project_output.metadata.labels, + ignore_order=True, ) == {} ) @@ -224,7 +245,8 @@ def test_create_project( "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_store_project_creation( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): project_name = "project-name" project_description = "some description" @@ -237,7 +259,9 @@ def test_store_project_creation( project_name, mlrun.api.schemas.Project( metadata=mlrun.api.schemas.ProjectMetadata( - name=project_name, created=project_created, labels=project_labels, + name=project_name, + created=project_created, + labels=project_labels, ), spec=mlrun.api.schemas.ProjectSpec(description=project_description), ), @@ -249,7 +273,9 @@ def test_store_project_creation( assert project_output.metadata.created != project_created assert ( deepdiff.DeepDiff( - project_labels, project_output.metadata.labels, ignore_order=True, + project_labels, + project_output.metadata.labels, + ignore_order=True, ) == {} 
) @@ -260,7 +286,8 @@ def test_store_project_creation( "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_store_project_update( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): project_name = "project-name" project_description = "some description" @@ -272,7 +299,9 @@ def test_store_project_update( db_session, mlrun.api.schemas.Project( metadata=mlrun.api.schemas.ProjectMetadata( - name=project_name, created=project_created, labels=project_labels, + name=project_name, + created=project_created, + labels=project_labels, ), spec=mlrun.api.schemas.ProjectSpec(description=project_description), ), @@ -298,7 +327,8 @@ def test_store_project_update( "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_patch_project( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): project_name = "project-name" project_description = "some description" @@ -330,7 +360,9 @@ def test_patch_project( assert project_output.metadata.created != project_created assert ( deepdiff.DeepDiff( - patched_project_labels, project_output.metadata.labels, ignore_order=True, + patched_project_labels, + project_output.metadata.labels, + ignore_order=True, ) == {} ) @@ -341,7 +373,8 @@ def test_patch_project( "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] ) def test_delete_project( - db: DBInterface, db_session: sqlalchemy.orm.Session, + db: DBInterface, + db_session: sqlalchemy.orm.Session, ): project_name = "project-name" project_description = "some description" diff --git a/tests/api/db/test_runs.py b/tests/api/db/test_runs.py index 2ac394a774..7b64de3bb4 100644 --- a/tests/api/db/test_runs.py +++ b/tests/api/db/test_runs.py @@ -241,7 +241,11 @@ def test_store_and_update_run_update_name_failure(db: DBInterface, db_session: S ): run["metadata"]["name"] = "new-name" db.store_run( - db_session, run, uid, project, iteration, + db_session, + run, + uid, + project, + iteration, ) with pytest.raises( @@ -249,7 +253,11 @@ def test_store_and_update_run_update_name_failure(db: DBInterface, db_session: S match="Changing name for an existing run is invalid", ): db.update_run( - db_session, {"metadata.name": "new-name"}, uid, project, iteration, + db_session, + {"metadata.name": "new-name"}, + uid, + project, + iteration, ) @@ -263,7 +271,9 @@ def test_list_runs_limited_unsorted_failure(db: DBInterface, db_session: Session match="Limiting the number of returned records without sorting will provide non-deterministic results", ): db.list_runs( - db_session, sort=False, last=1, + db_session, + sort=False, + last=1, ) diff --git a/tests/api/runtime_handlers/base.py b/tests/api/runtime_handlers/base.py index 1c10cfa2ab..fd544092b6 100644 --- a/tests/api/runtime_handlers/base.py +++ b/tests/api/runtime_handlers/base.py @@ -166,7 +166,8 @@ def _assert_runtime_handler_list_resources( resources = runtime_handler.list_resources(project, group_by=group_by) crd_group, crd_version, crd_plural = runtime_handler._get_crd_info() get_k8s().v1api.list_namespaced_pod.assert_called_once_with( - get_k8s().resolve_namespace(), label_selector=label_selector, + get_k8s().resolve_namespace(), + label_selector=label_selector, ) if expected_crds: get_k8s().crdapi.list_namespaced_custom_object.assert_called_once_with( @@ -178,7 +179,8 @@ def _assert_runtime_handler_list_resources( ) if expected_services: get_k8s().v1api.list_namespaced_service.assert_called_once_with( - 
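# --- Illustrative sketch (not part of the patch): the runtime-handler base
# assertions above verify that the kubernetes client mock was queried exactly
# once with the expected namespace and label selector. The same check in
# isolation, using invented names:
import unittest.mock

v1api = unittest.mock.Mock()
v1api.list_namespaced_pod.return_value = []  # pretend the cluster has no pods

def list_pods(namespace, label_selector):
    return v1api.list_namespaced_pod(namespace, label_selector=label_selector)

list_pods("default-tenant", label_selector="mlrun/class=job")

# fails if the call happened zero times, more than once, or with other arguments
v1api.list_namespaced_pod.assert_called_once_with(
    "default-tenant",
    label_selector="mlrun/class=job",
)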
get_k8s().resolve_namespace(), label_selector=label_selector, + get_k8s().resolve_namespace(), + label_selector=label_selector, ) assertion_func( self, @@ -289,7 +291,9 @@ def _assert_resource_in_response_resources( ) assert ( deepdiff.DeepDiff( - resource.status, expected_resource["status"], ignore_order=True, + resource.status, + expected_resource["status"], + ignore_order=True, ) == {} ) @@ -474,7 +478,8 @@ def _assert_run_logs( ): if logger_pod_name is not None: get_k8s().v1api.read_namespaced_pod_log.assert_called_once_with( - name=logger_pod_name, namespace=get_k8s().resolve_namespace(), + name=logger_pod_name, + namespace=get_k8s().resolve_namespace(), ) _, log = crud.Logs().get_logs(db, project, uid, source=LogSources.PERSISTENCY) assert log == expected_log.encode() diff --git a/tests/api/runtime_handlers/test_daskjob.py b/tests/api/runtime_handlers/test_daskjob.py index cdfb0fd0c7..cdd0ed55bd 100644 --- a/tests/api/runtime_handlers/test_daskjob.py +++ b/tests/api/runtime_handlers/test_daskjob.py @@ -28,10 +28,14 @@ def custom_setup(self): scheduler_pod_name = "mlrun-mydask-d7656bc1-0n4z9z" self.running_scheduler_pod = self._generate_pod( - scheduler_pod_name, scheduler_pod_labels, PodPhases.running, + scheduler_pod_name, + scheduler_pod_labels, + PodPhases.running, ) self.completed_scheduler_pod = self._generate_pod( - scheduler_pod_name, scheduler_pod_labels, PodPhases.succeeded, + scheduler_pod_name, + scheduler_pod_labels, + PodPhases.succeeded, ) worker_pod_labels = { @@ -48,10 +52,14 @@ def custom_setup(self): worker_pod_name = "mlrun-mydask-d7656bc1-0pqbnc" self.running_worker_pod = self._generate_pod( - worker_pod_name, worker_pod_labels, PodPhases.running, + worker_pod_name, + worker_pod_labels, + PodPhases.running, ) self.completed_worker_pod = self._generate_pod( - worker_pod_name, worker_pod_labels, PodPhases.succeeded, + worker_pod_name, + worker_pod_labels, + PodPhases.succeeded, ) service_name = "mlrun-mydask-d7656bc1-0" @@ -73,7 +81,9 @@ def test_list_resources(self, db: Session, client: TestClient): pods = self._mock_list_resources_pods() services = self._mock_list_services([self.cluster_service]) self._assert_runtime_handler_list_resources( - RuntimeKinds.dask, expected_pods=pods, expected_services=services, + RuntimeKinds.dask, + expected_pods=pods, + expected_services=services, ) def test_delete_resources_completed_cluster(self, db: Session, client: TestClient): diff --git a/tests/api/runtime_handlers/test_kubejob.py b/tests/api/runtime_handlers/test_kubejob.py index 7f37d0aef4..f0b5cfedbb 100644 --- a/tests/api/runtime_handlers/test_kubejob.py +++ b/tests/api/runtime_handlers/test_kubejob.py @@ -69,7 +69,9 @@ def test_list_resources_grouped_by(self, db: Session, client: TestClient): ]: pods = self._mock_list_resources_pods() self._assert_runtime_handler_list_resources( - RuntimeKinds.job, expected_pods=pods, group_by=group_by, + RuntimeKinds.job, + expected_pods=pods, + group_by=group_by, ) def test_list_resources_grouped_by_project_with_non_project_resources( @@ -110,7 +112,11 @@ def test_delete_resources_completed_pod(self, db: Session, client: TestClient): db, self.project, self.run_uid, RunStates.completed ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.completed_job_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.completed_job_pod.metadata.name, ) def test_delete_resources_completed_builder_pod( @@ -190,7 +196,11 @@ def test_delete_resources_with_force(self, db: Session, client: TestClient): db, self.project, 
self.run_uid, RunStates.running ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.running_job_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.running_job_pod.metadata.name, ) def test_monitor_run_completed_pod(self, db: Session, client: TestClient): @@ -216,7 +226,11 @@ def test_monitor_run_completed_pod(self, db: Session, client: TestClient): db, self.project, self.run_uid, RunStates.completed ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.completed_job_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.completed_job_pod.metadata.name, ) def test_monitor_run_failed_pod(self, db: Session, client: TestClient): @@ -240,7 +254,11 @@ def test_monitor_run_failed_pod(self, db: Session, client: TestClient): ) self._assert_run_reached_state(db, self.project, self.run_uid, RunStates.error) self._assert_run_logs( - db, self.project, self.run_uid, log, self.failed_job_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.failed_job_pod.metadata.name, ) def test_monitor_run_no_pods(self, db: Session, client: TestClient): @@ -289,7 +307,11 @@ def test_monitor_run_overriding_terminal_state( ) self._assert_run_reached_state(db, self.project, self.run_uid, RunStates.error) self._assert_run_logs( - db, self.project, self.run_uid, log, self.completed_job_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.completed_job_pod.metadata.name, ) def test_monitor_run_debouncing_non_terminal_state( @@ -303,8 +325,8 @@ def test_monitor_run_debouncing_non_terminal_state( original_update_run_updated_time = ( mlrun.api.utils.singletons.db.get_db()._update_run_updated_time ) - mlrun.api.utils.singletons.db.get_db()._update_run_updated_time = tests.conftest.freeze( - original_update_run_updated_time, now=now_date() + mlrun.api.utils.singletons.db.get_db()._update_run_updated_time = ( + tests.conftest.freeze(original_update_run_updated_time, now=now_date()) ) mlrun.api.crud.Runs().store_run( db, self.run, self.run_uid, project=self.project @@ -326,9 +348,11 @@ def test_monitor_run_debouncing_non_terminal_state( # Mocking that update occurred before debounced period debounce_period = config.runs_monitoring_interval - mlrun.api.utils.singletons.db.get_db()._update_run_updated_time = tests.conftest.freeze( - original_update_run_updated_time, - now=now_date() - timedelta(seconds=float(2 * debounce_period)), + mlrun.api.utils.singletons.db.get_db()._update_run_updated_time = ( + tests.conftest.freeze( + original_update_run_updated_time, + now=now_date() - timedelta(seconds=float(2 * debounce_period)), + ) ) mlrun.api.crud.Runs().store_run( db, self.run, self.run_uid, project=self.project @@ -365,7 +389,11 @@ def test_monitor_run_debouncing_non_terminal_state( ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.completed_job_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.completed_job_pod.metadata.name, ) def test_monitor_run_run_does_not_exists(self, db: Session, client: TestClient): @@ -390,7 +418,11 @@ def test_monitor_run_run_does_not_exists(self, db: Session, client: TestClient): db, self.project, self.run_uid, RunStates.completed ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.completed_job_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.completed_job_pod.metadata.name, ) def _mock_list_resources_pods(self, pod=None): diff --git a/tests/api/runtime_handlers/test_mpijob.py b/tests/api/runtime_handlers/test_mpijob.py index 
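# --- Illustrative sketch (not part of the patch): the debouncing test above
# swaps the real "update run updated-time" method for a wrapper produced by
# tests.conftest.freeze, pinning "now" to a chosen timestamp. A generic freeze
# helper of that shape could look like this (names and behaviour assumed, the
# real helper may differ):
import datetime
import functools

def freeze(func, now):
    # return the same callable with a fixed timestamp injected as `now`
    return functools.partial(func, now=now)

def update_run_updated_time(run, now=None):
    run["last_update"] = (now or datetime.datetime.utcnow()).isoformat()
    return run

two_periods_ago = datetime.datetime.utcnow() - datetime.timedelta(seconds=60)
frozen = freeze(update_run_updated_time, now=two_periods_ago)

# the wrapped function now reports an update time far enough in the past to
# fall outside the monitoring debounce period
assert frozen({})["last_update"] == two_periods_ago.isoformat()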
b8bc73c2c4..b12089b21b 100644 --- a/tests/api/runtime_handlers/test_mpijob.py +++ b/tests/api/runtime_handlers/test_mpijob.py @@ -18,15 +18,24 @@ def custom_setup(self): # initializing them here to save space in tests self.active_crd_dict = self._generate_mpijob_crd( - self.project, self.run_uid, self._get_active_crd_status(), + self.project, + self.run_uid, + self._get_active_crd_status(), ) self.succeeded_crd_dict = self._generate_mpijob_crd( - self.project, self.run_uid, self._get_succeeded_crd_status(), + self.project, + self.run_uid, + self._get_succeeded_crd_status(), ) self.failed_crd_dict = self._generate_mpijob_crd( - self.project, self.run_uid, self._get_failed_crd_status(), + self.project, + self.run_uid, + self._get_failed_crd_status(), + ) + self.no_status_crd_dict = self._generate_mpijob_crd( + self.project, + self.run_uid, ) - self.no_status_crd_dict = self._generate_mpijob_crd(self.project, self.run_uid,) launcher_pod_labels = { "group-name": "kubeflow.org", @@ -45,7 +54,9 @@ def custom_setup(self): launcher_pod_name = "trainer-1b019005-launcher" self.launcher_pod = self._generate_pod( - launcher_pod_name, launcher_pod_labels, PodPhases.running, + launcher_pod_name, + launcher_pod_labels, + PodPhases.running, ) worker_pod_labels = { @@ -65,7 +76,9 @@ def custom_setup(self): worker_pod_name = "trainer-1b019005-worker-0" self.worker_pod = self._generate_pod( - worker_pod_name, worker_pod_labels, PodPhases.running, + worker_pod_name, + worker_pod_labels, + PodPhases.running, ) self.pod_label_selector = self._generate_get_logger_pods_label_selector( @@ -134,7 +147,8 @@ def test_delete_resources_succeeded_crd(self, db: Session, client: TestClient): self.succeeded_crd_dict["metadata"]["namespace"], ) self._assert_list_namespaced_crds_calls( - self.runtime_handler, len(list_namespaced_crds_calls), + self.runtime_handler, + len(list_namespaced_crds_calls), ) self._assert_list_namespaced_pods_calls( self.runtime_handler, @@ -145,7 +159,11 @@ def test_delete_resources_succeeded_crd(self, db: Session, client: TestClient): db, self.project, self.run_uid, RunStates.completed ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.launcher_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.launcher_pod.metadata.name, ) def test_delete_resources_running_crd(self, db: Session, client: TestClient): @@ -158,10 +176,12 @@ def test_delete_resources_running_crd(self, db: Session, client: TestClient): # nothing removed cause crd is active self._assert_delete_namespaced_custom_objects( - self.runtime_handler, [], + self.runtime_handler, + [], ) self._assert_list_namespaced_crds_calls( - self.runtime_handler, len(list_namespaced_crds_calls), + self.runtime_handler, + len(list_namespaced_crds_calls), ) def test_delete_resources_with_grace_period(self, db: Session, client: TestClient): @@ -179,10 +199,12 @@ def test_delete_resources_with_grace_period(self, db: Session, client: TestClien # nothing removed cause grace period didn't pass self._assert_delete_namespaced_custom_objects( - self.runtime_handler, [], + self.runtime_handler, + [], ) self._assert_list_namespaced_crds_calls( - self.runtime_handler, len(list_namespaced_crds_calls), + self.runtime_handler, + len(list_namespaced_crds_calls), ) def test_delete_resources_with_force(self, db: Session, client: TestClient): @@ -208,7 +230,8 @@ def test_delete_resources_with_force(self, db: Session, client: TestClient): self.active_crd_dict["metadata"]["namespace"], ) self._assert_list_namespaced_crds_calls( - 
self.runtime_handler, len(list_namespaced_crds_calls), + self.runtime_handler, + len(list_namespaced_crds_calls), ) self._assert_list_namespaced_pods_calls( self.runtime_handler, @@ -219,7 +242,11 @@ def test_delete_resources_with_force(self, db: Session, client: TestClient): db, self.project, self.run_uid, RunStates.running ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.launcher_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.launcher_pod.metadata.name, ) def test_monitor_run_succeeded_crd(self, db: Session, client: TestClient): @@ -241,7 +268,8 @@ def test_monitor_run_succeeded_crd(self, db: Session, client: TestClient): for _ in range(expected_monitor_cycles_to_reach_expected_state): self.runtime_handler.monitor_runs(get_db(), db) self._assert_list_namespaced_crds_calls( - self.runtime_handler, expected_number_of_list_crds_calls, + self.runtime_handler, + expected_number_of_list_crds_calls, ) self._assert_list_namespaced_pods_calls( self.runtime_handler, @@ -252,7 +280,11 @@ def test_monitor_run_succeeded_crd(self, db: Session, client: TestClient): db, self.project, self.run_uid, RunStates.completed ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.launcher_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.launcher_pod.metadata.name, ) def test_monitor_run_failed_crd(self, db: Session, client: TestClient): @@ -274,7 +306,8 @@ def test_monitor_run_failed_crd(self, db: Session, client: TestClient): for _ in range(expected_monitor_cycles_to_reach_expected_state): self.runtime_handler.monitor_runs(get_db(), db) self._assert_list_namespaced_crds_calls( - self.runtime_handler, expected_number_of_list_crds_calls, + self.runtime_handler, + expected_number_of_list_crds_calls, ) self._assert_list_namespaced_pods_calls( self.runtime_handler, @@ -283,7 +316,11 @@ def test_monitor_run_failed_crd(self, db: Session, client: TestClient): ) self._assert_run_reached_state(db, self.project, self.run_uid, RunStates.error) self._assert_run_logs( - db, self.project, self.run_uid, log, self.launcher_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.launcher_pod.metadata.name, ) def _mock_list_resources_pods(self): diff --git a/tests/api/runtime_handlers/test_sparkjob.py b/tests/api/runtime_handlers/test_sparkjob.py index 5f26626cb1..6c2919e915 100644 --- a/tests/api/runtime_handlers/test_sparkjob.py +++ b/tests/api/runtime_handlers/test_sparkjob.py @@ -18,13 +18,19 @@ def custom_setup(self): # initializing them here to save space in tests self.running_crd_dict = self._generate_sparkjob_crd( - self.project, self.run_uid, self._get_running_crd_status(), + self.project, + self.run_uid, + self._get_running_crd_status(), ) self.completed_crd_dict = self._generate_sparkjob_crd( - self.project, self.run_uid, self._get_completed_crd_status(), + self.project, + self.run_uid, + self._get_completed_crd_status(), ) self.failed_crd_dict = self._generate_sparkjob_crd( - self.project, self.run_uid, self._get_failed_crd_status(), + self.project, + self.run_uid, + self._get_failed_crd_status(), ) executor_pod_labels = { @@ -46,7 +52,9 @@ def custom_setup(self): executor_pod_name = "my-spark-jdbc-2ea432f1-1597760338437-exec-1" self.executor_pod = self._generate_pod( - executor_pod_name, executor_pod_labels, PodPhases.running, + executor_pod_name, + executor_pod_labels, + PodPhases.running, ) driver_pod_labels = { @@ -67,7 +75,9 @@ def custom_setup(self): driver_pod_name = "my-spark-jdbc-2ea432f1-driver" self.driver_pod = 
self._generate_pod( - driver_pod_name, driver_pod_labels, PodPhases.running, + driver_pod_name, + driver_pod_labels, + PodPhases.running, ) self.pod_label_selector = self._generate_get_logger_pods_label_selector( @@ -130,7 +140,8 @@ def test_delete_resources_completed_crd(self, db: Session, client: TestClient): self.completed_crd_dict["metadata"]["namespace"], ) self._assert_list_namespaced_crds_calls( - self.runtime_handler, len(list_namespaced_crds_calls), + self.runtime_handler, + len(list_namespaced_crds_calls), ) self._assert_list_namespaced_pods_calls( self.runtime_handler, @@ -141,7 +152,11 @@ def test_delete_resources_completed_crd(self, db: Session, client: TestClient): db, self.project, self.run_uid, RunStates.completed ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.driver_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.driver_pod.metadata.name, ) def test_delete_resources_running_crd(self, db: Session, client: TestClient): @@ -155,10 +170,12 @@ def test_delete_resources_running_crd(self, db: Session, client: TestClient): # nothing removed cause crd is running self._assert_delete_namespaced_custom_objects( - self.runtime_handler, [], + self.runtime_handler, + [], ) self._assert_list_namespaced_crds_calls( - self.runtime_handler, len(list_namespaced_crds_calls), + self.runtime_handler, + len(list_namespaced_crds_calls), ) def test_delete_resources_with_grace_period(self, db: Session, client: TestClient): @@ -177,10 +194,12 @@ def test_delete_resources_with_grace_period(self, db: Session, client: TestClien # nothing removed cause grace period didn't pass self._assert_delete_namespaced_custom_objects( - self.runtime_handler, [], + self.runtime_handler, + [], ) self._assert_list_namespaced_crds_calls( - self.runtime_handler, len(list_namespaced_crds_calls), + self.runtime_handler, + len(list_namespaced_crds_calls), ) def test_delete_resources_with_force(self, db: Session, client: TestClient): @@ -207,7 +226,8 @@ def test_delete_resources_with_force(self, db: Session, client: TestClient): self.running_crd_dict["metadata"]["namespace"], ) self._assert_list_namespaced_crds_calls( - self.runtime_handler, len(list_namespaced_crds_calls), + self.runtime_handler, + len(list_namespaced_crds_calls), ) self._assert_list_namespaced_pods_calls( self.runtime_handler, @@ -218,7 +238,11 @@ def test_delete_resources_with_force(self, db: Session, client: TestClient): db, self.project, self.run_uid, RunStates.running ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.driver_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.driver_pod.metadata.name, ) def test_monitor_run_completed_crd(self, db: Session, client: TestClient): @@ -240,7 +264,8 @@ def test_monitor_run_completed_crd(self, db: Session, client: TestClient): for _ in range(expected_monitor_cycles_to_reach_expected_state): self.runtime_handler.monitor_runs(get_db(), db) self._assert_list_namespaced_crds_calls( - self.runtime_handler, expected_number_of_list_crds_calls, + self.runtime_handler, + expected_number_of_list_crds_calls, ) self._assert_list_namespaced_pods_calls( self.runtime_handler, @@ -251,7 +276,11 @@ def test_monitor_run_completed_crd(self, db: Session, client: TestClient): db, self.project, self.run_uid, RunStates.completed ) self._assert_run_logs( - db, self.project, self.run_uid, log, self.driver_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.driver_pod.metadata.name, ) def test_monitor_run_failed_crd(self, db: Session, 
client: TestClient): @@ -273,7 +302,8 @@ def test_monitor_run_failed_crd(self, db: Session, client: TestClient): for _ in range(expected_monitor_cycles_to_reach_expected_state): self.runtime_handler.monitor_runs(get_db(), db) self._assert_list_namespaced_crds_calls( - self.runtime_handler, expected_number_of_list_crds_calls, + self.runtime_handler, + expected_number_of_list_crds_calls, ) self._assert_list_namespaced_pods_calls( self.runtime_handler, @@ -282,7 +312,11 @@ def test_monitor_run_failed_crd(self, db: Session, client: TestClient): ) self._assert_run_reached_state(db, self.project, self.run_uid, RunStates.error) self._assert_run_logs( - db, self.project, self.run_uid, log, self.driver_pod.metadata.name, + db, + self.project, + self.run_uid, + log, + self.driver_pod.metadata.name, ) def _mock_list_resources_pods(self): diff --git a/tests/api/runtimes/base.py b/tests/api/runtimes/base.py index 6b115113e6..4ef3d15905 100644 --- a/tests/api/runtimes/base.py +++ b/tests/api/runtimes/base.py @@ -317,7 +317,9 @@ def _assert_function_config( ) if expected_labels: diff_result = deepdiff.DeepDiff( - function_metadata["labels"], expected_labels, ignore_order=True, + function_metadata["labels"], + expected_labels, + ignore_order=True, ) # We just care that the values we look for are fully there. diff_result.pop("dictionary_item_removed", None) @@ -542,7 +544,9 @@ def _assert_pod_creation_config( if expected_node_selector: assert ( deepdiff.DeepDiff( - pod.spec.node_selector, expected_node_selector, ignore_order=True, + pod.spec.node_selector, + expected_node_selector, + ignore_order=True, ) == {} ) diff --git a/tests/api/runtimes/test_dask.py b/tests/api/runtimes/test_dask.py index 705124441f..b714e32f16 100644 --- a/tests/api/runtimes/test_dask.py +++ b/tests/api/runtimes/test_dask.py @@ -84,7 +84,9 @@ def _generate_runtime(self): return dask_cluster - def _assert_scheduler_pod_args(self,): + def _assert_scheduler_pod_args( + self, + ): scheduler_pod = self._get_scheduler_pod_creation_args() scheduler_container_spec = scheduler_pod.spec.containers[0] assert scheduler_container_spec.args == ["dask-scheduler"] @@ -146,7 +148,8 @@ def test_dask_runtime_with_resources(self, db: Session, client: TestClient): cpu=expected_scheduler_limits["cpu"], ) runtime.with_worker_limits( - mem=expected_worker_limits["memory"], cpu=expected_worker_limits["cpu"], + mem=expected_worker_limits["memory"], + cpu=expected_worker_limits["cpu"], ) runtime.gpus(expected_gpus, gpu_type) _ = runtime.client diff --git a/tests/api/runtimes/test_kubejob.py b/tests/api/runtimes/test_kubejob.py index 94f053298d..e134e4e715 100644 --- a/tests/api/runtimes/test_kubejob.py +++ b/tests/api/runtimes/test_kubejob.py @@ -212,8 +212,10 @@ def test_run_with_k8s_secrets(self, db: Session, k8s_secrets_mock: K8sSecretsMoc # We don't expect the internal secret to be visible - the user cannot mount it to the function # even if specifically asking for it in with_secrets() - expected_env_from_secrets = k8s_secrets_mock.get_expected_env_variables_from_secrets( - self.project, include_internal=False + expected_env_from_secrets = ( + k8s_secrets_mock.get_expected_env_variables_from_secrets( + self.project, include_internal=False + ) ) self._assert_pod_creation_config( @@ -331,7 +333,9 @@ def test_with_requirements(self, db: Session, client: TestClient): expected_commands = ["python -m pip install faker python-dotenv"] assert ( deepdiff.DeepDiff( - expected_commands, runtime.spec.build.commands, ignore_order=True, + expected_commands, + 
runtime.spec.build.commands, + ignore_order=True, ) == {} ) diff --git a/tests/api/runtimes/test_nuclio.py b/tests/api/runtimes/test_nuclio.py index d0ce4da5c1..e4e6ba0170 100644 --- a/tests/api/runtimes/test_nuclio.py +++ b/tests/api/runtimes/test_nuclio.py @@ -331,7 +331,9 @@ def test_enrich_with_ingress_on_cluster_ip(self, db: Session, client: TestClient function_name, project_name, config = compile_function_config(function) service_type = "ClusterIP" enrich_function_with_ingress( - config, NuclioIngressAddTemplatedIngressModes.on_cluster_ip, service_type, + config, + NuclioIngressAddTemplatedIngressModes.on_cluster_ip, + service_type, ) ingresses = resolve_function_ingresses(config["spec"]) assert ingresses[0]["hostTemplate"] != "" @@ -449,7 +451,7 @@ def test_deploy_image_name_and_build_base_image( self, db: Session, k8s_secrets_mock: K8sSecretsMock ): """When spec.image and also spec.build.base_image are both defined the spec.image should be applied - to spec.baseImage in nuclio.""" + to spec.baseImage in nuclio.""" function = self._generate_runtime(self.runtime_kind) function.spec.build.base_image = "mlrun/base_mlrun:latest" diff --git a/tests/api/utils/auth/providers/test_opa.py b/tests/api/utils/auth/providers/test_opa.py index c4755e9794..d84b8be210 100644 --- a/tests/api/utils/auth/providers/test_opa.py +++ b/tests/api/utils/auth/providers/test_opa.py @@ -36,7 +36,9 @@ async def permission_filter_path() -> str: @pytest.fixture() async def opa_provider( - api_url: str, permission_query_path: str, permission_filter_path: str, + api_url: str, + permission_query_path: str, + permission_filter_path: str, ) -> mlrun.api.utils.auth.providers.opa.Provider: mlrun.mlconf.httpdb.authorization.opa.log_level = 10 mlrun.mlconf.httpdb.authorization.mode = "opa" @@ -133,7 +135,9 @@ def mock_filter_query_success(request, context): ) assert ( deepdiff.DeepDiff( - expected_allowed_resources, allowed_resources, ignore_order=True, + expected_allowed_resources, + allowed_resources, + ignore_order=True, ) == {} ) diff --git a/tests/api/utils/clients/test_iguazio.py b/tests/api/utils/clients/test_iguazio.py index e9b3393edf..f14d3ab293 100644 --- a/tests/api/utils/clients/test_iguazio.py +++ b/tests/api/utils/clients/test_iguazio.py @@ -24,7 +24,9 @@ async def api_url() -> str: @pytest.fixture() -async def iguazio_client(api_url: str,) -> mlrun.api.utils.clients.iguazio.Client: +async def iguazio_client( + api_url: str, +) -> mlrun.api.utils.clients.iguazio.Client: client = mlrun.api.utils.clients.iguazio.Client() # force running init again so the configured api url will be used client.__init__() @@ -184,7 +186,11 @@ def _get_or_create_access_key_mock(status_code, request, context): } } assert ( - deepdiff.DeepDiff(expected_request_body, request.json(), ignore_order=True,) + deepdiff.DeepDiff( + expected_request_body, + request.json(), + ignore_order=True, + ) == {} ) return {"data": {"id": access_key_id}} @@ -235,9 +241,13 @@ def verify_get(request, context): # mock project response so store will update requests_mock.get( - f"{api_url}/api/projects/__name__/{project.metadata.name}", json=verify_get, + f"{api_url}/api/projects/__name__/{project.metadata.name}", + json=verify_get, + ) + project_owner = iguazio_client.get_project_owner( + session, + project.metadata.name, ) - project_owner = iguazio_client.get_project_owner(session, project.metadata.name,) assert project_owner.username == owner_username assert project_owner.session == owner_access_key @@ -269,10 +279,12 @@ def verify_list(request, 
context): # mock project response so store will update requests_mock.get( - f"{api_url}/api/projects", json=verify_list, + f"{api_url}/api/projects", + json=verify_list, ) iguazio_client.list_projects( - session, updated_after, + session, + updated_after, ) @@ -381,7 +393,8 @@ def test_create_project_failures( mlrun.errors.MLRunBadRequestError, match=rf"(.*){error_message}(.*)" ): iguazio_client.create_project( - session, project, + session, + project, ) # mock job failure - with nice error message in result @@ -410,7 +423,8 @@ def test_create_project_failures( mlrun.errors.MLRunBadRequestError, match=rf"(.*){error_message}(.*)" ): iguazio_client.create_project( - session, project, + session, + project, ) # mock job failure - without nice error message (shouldn't happen, but let's test) @@ -424,7 +438,8 @@ def test_create_project_failures( with pytest.raises(mlrun.errors.MLRunRuntimeError): iguazio_client.create_project( - session, project, + session, + project, ) @@ -434,7 +449,9 @@ def test_create_project_minimal_project( requests_mock: requests_mock_package.Mocker, ): project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name="some-name",), + metadata=mlrun.api.schemas.ProjectMetadata( + name="some-name", + ), ) _create_project_and_assert(api_url, iguazio_client, requests_mock, project) @@ -479,7 +496,9 @@ def verify_store_update(request, context): json=verify_store_update, ) iguazio_client.update_project( - session, project.metadata.name, project, + session, + project.metadata.name, + project, ) @@ -538,7 +557,8 @@ def test_delete_project_without_wait( def test_format_as_leader_project( - api_url: str, iguazio_client: mlrun.api.utils.clients.iguazio.Client, + api_url: str, + iguazio_client: mlrun.api.utils.clients.iguazio.Client, ): project = _generate_project() iguazio_project = iguazio_client.format_as_leader_project(project) @@ -623,7 +643,10 @@ def _create_project_and_assert( f"{api_url}/api/projects/__name__/{project.metadata.name}", json={"data": _build_project_response(iguazio_client, project)}, ) - is_running_in_background = iguazio_client.create_project(session, project,) + is_running_in_background = iguazio_client.create_project( + session, + project, + ) assert is_running_in_background is False assert mocker.call_count == num_of_calls_until_completion @@ -724,7 +747,9 @@ def _generate_project( owner=owner, some_extra_field="some value", ), - status=mlrun.api.schemas.ProjectStatus(some_extra_field="some value",), + status=mlrun.api.schemas.ProjectStatus( + some_extra_field="some value", + ), ) diff --git a/tests/api/utils/clients/test_nuclio.py b/tests/api/utils/clients/test_nuclio.py index ae2dc341ee..2c96dff7f7 100644 --- a/tests/api/utils/clients/test_nuclio.py +++ b/tests/api/utils/clients/test_nuclio.py @@ -18,7 +18,9 @@ async def api_url() -> str: @pytest.fixture() -async def nuclio_client(api_url: str,) -> mlrun.api.utils.clients.nuclio.Client: +async def nuclio_client( + api_url: str, +) -> mlrun.api.utils.clients.nuclio.Client: client = mlrun.api.utils.clients.nuclio.Client() # force running init again so the configured api url will be used client.__init__() @@ -50,12 +52,18 @@ def test_get_project( assert project.metadata.name == project_name assert project.spec.description == project_description assert ( - deepdiff.DeepDiff(project_labels, project.metadata.labels, ignore_order=True,) + deepdiff.DeepDiff( + project_labels, + project.metadata.labels, + ignore_order=True, + ) == {} ) assert ( deepdiff.DeepDiff( - project_annotations, 
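# --- Illustrative sketch (not part of the patch): the iguazio/nuclio client
# tests above register callbacks with requests-mock and assert on the outgoing
# request body inside them. A self-contained version of that pattern, with an
# invented URL and payload:
import deepdiff
import requests
import requests_mock

def verify_put(request, context):
    # the callback receives the captured request and a context that shapes the
    # mocked response
    assert (
        deepdiff.DeepDiff(
            {"metadata": {"name": "some-project"}},
            request.json(),
            ignore_order=True,
        )
        == {}
    )
    context.status_code = 204
    return ""

with requests_mock.Mocker() as mocker:
    mocker.put("http://nuclio-dashboard/api/projects", text=verify_put)
    response = requests.put(
        "http://nuclio-dashboard/api/projects",
        json={"metadata": {"name": "some-project"}},
    )
    assert response.status_code == 204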
project.metadata.annotations, ignore_order=True, + project_annotations, + project.metadata.annotations, + ignore_order=True, ) == {} ) @@ -297,7 +305,12 @@ def verify_patch(request, context): expected_body["metadata"]["labels"].update(project_labels) expected_body["metadata"]["annotations"].update(project_annotations) assert ( - deepdiff.DeepDiff(expected_body, request.json(), ignore_order=True,) == {} + deepdiff.DeepDiff( + expected_body, + request.json(), + ignore_order=True, + ) + == {} ) context.status_code = http.HTTPStatus.NO_CONTENT.value @@ -325,7 +338,8 @@ def test_patch_project_only_labels( "some-label": "some-label-value", } mocked_project_body = _generate_project_body( - project_name, labels={"label-key": "label-value"}, + project_name, + labels={"label-key": "label-value"}, ) def verify_patch(request, context): @@ -333,7 +347,12 @@ def verify_patch(request, context): expected_body = mocked_project_body expected_body["metadata"]["labels"].update(project_labels) assert ( - deepdiff.DeepDiff(expected_body, request.json(), ignore_order=True,) == {} + deepdiff.DeepDiff( + expected_body, + request.json(), + ignore_order=True, + ) + == {} ) context.status_code = http.HTTPStatus.NO_CONTENT.value @@ -342,7 +361,9 @@ def verify_patch(request, context): ) requests_mock.put(f"{api_url}/api/projects", json=verify_patch) nuclio_client.patch_project( - None, project_name, {"metadata": {"labels": project_labels}}, + None, + project_name, + {"metadata": {"labels": project_labels}}, ) diff --git a/tests/api/utils/projects/test_follower_member.py b/tests/api/utils/projects/test_follower_member.py index bbf4fd6438..d6f85a6c5d 100644 --- a/tests/api/utils/projects/test_follower_member.py +++ b/tests/api/utils/projects/test_follower_member.py @@ -82,7 +82,10 @@ def test_create_project( nop_leader: mlrun.api.utils.projects.remotes.leader.Member, ): project = _generate_project() - created_project, _ = projects_follower.create_project(None, project,) + created_project, _ = projects_follower.create_project( + None, + project, + ) _assert_projects_equal(project, created_project) _assert_project_in_follower(projects_follower, project) @@ -96,7 +99,9 @@ def test_store_project( # project doesn't exist - store will create created_project, _ = projects_follower.store_project( - None, project.metadata.name, project, + None, + project.metadata.name, + project, ) _assert_projects_equal(project, created_project) _assert_project_in_follower(projects_follower, project) @@ -104,7 +109,9 @@ def test_store_project( project_update = _generate_project(description="new description") # project exists - store will update updated_project, _ = projects_follower.store_project( - None, project.metadata.name, project_update, + None, + project.metadata.name, + project_update, ) _assert_projects_equal(project_update, updated_project) _assert_project_in_follower(projects_follower, project_update) @@ -119,7 +126,9 @@ def test_patch_project( # project doesn't exist - store will create created_project, _ = projects_follower.store_project( - None, project.metadata.name, project, + None, + project.metadata.name, + project, ) _assert_projects_equal(project, created_project) _assert_project_in_follower(projects_follower, project) @@ -142,20 +151,23 @@ def test_delete_project( ): project = _generate_project() projects_follower.create_project( - None, project, + None, + project, ) _assert_project_in_follower(projects_follower, project) - mlrun.api.utils.singletons.db.get_db().verify_project_has_no_related_resources = unittest.mock.Mock( - 
return_value=None + mlrun.api.utils.singletons.db.get_db().verify_project_has_no_related_resources = ( + unittest.mock.Mock(return_value=None) ) projects_follower.delete_project( - None, project.metadata.name, + None, + project.metadata.name, ) _assert_project_not_in_follower(projects_follower, project.metadata.name) # make sure another delete doesn't fail projects_follower.delete_project( - None, project.metadata.name, + None, + project.metadata.name, ) @@ -166,7 +178,8 @@ def test_get_project( ): project = _generate_project() projects_follower.create_project( - None, project, + None, + project, ) # this functions uses get_project to assert, second assert will verify we're raising not found error _assert_project_in_follower(projects_follower, project) @@ -183,7 +196,8 @@ def test_get_project_owner( nop_leader.project_owner_session = owner_session project = _generate_project(owner=owner) projects_follower.create_project( - None, project, + None, + project, ) project_owner = projects_follower.get_project_owner(None, project.metadata.name) assert project_owner.username == owner @@ -223,7 +237,8 @@ def test_list_project( } for _project in all_projects.values(): projects_follower.create_project( - None, _project, + None, + _project, ) # list all _assert_list_projects(projects_follower, list(all_projects.values())) @@ -237,7 +252,9 @@ def test_list_project( # list by owner _assert_list_projects( - projects_follower, [project, archived_project], owner=owner, + projects_follower, + [project, archived_project], + owner=owner, ) # list specific names only @@ -249,7 +266,9 @@ def test_list_project( # list no valid names _assert_list_projects( - projects_follower, [], names=[], + projects_follower, + [], + names=[], ) # list labeled - key existence @@ -339,7 +358,11 @@ def test_list_project_leader_format( projects_role=mlrun.api.schemas.ProjectsRole.nop, ) assert ( - deepdiff.DeepDiff(projects.projects[0].data, project.dict(), ignore_order=True,) + deepdiff.DeepDiff( + projects.projects[0].data, + project.dict(), + ignore_order=True, + ) == {} ) @@ -364,7 +387,9 @@ def _assert_list_projects( assert len(projects.projects) == len(expected_projects) assert ( deepdiff.DeepDiff( - projects.projects, list(expected_projects_map.keys()), ignore_order=True, + projects.projects, + list(expected_projects_map.keys()), + ignore_order=True, ) == {} ) @@ -381,15 +406,24 @@ def _generate_project( return mlrun.api.schemas.Project( metadata=mlrun.api.schemas.ProjectMetadata(name=name, labels=labels), spec=mlrun.api.schemas.ProjectSpec( - description=description, desired_state=desired_state, owner=owner, + description=description, + desired_state=desired_state, + owner=owner, + ), + status=mlrun.api.schemas.ProjectStatus( + state=state, ), - status=mlrun.api.schemas.ProjectStatus(state=state,), ) def _assert_projects_equal(project_1, project_2): assert ( - deepdiff.DeepDiff(project_1.dict(), project_2.dict(), ignore_order=True,) == {} + deepdiff.DeepDiff( + project_1.dict(), + project_2.dict(), + ignore_order=True, + ) + == {} ) diff --git a/tests/api/utils/projects/test_leader_member.py b/tests/api/utils/projects/test_leader_member.py index 4ddfa0d6cb..28ab27bdda 100644 --- a/tests/api/utils/projects/test_leader_member.py +++ b/tests/api/utils/projects/test_leader_member.py @@ -63,7 +63,8 @@ def test_projects_sync_follower_project_adoption( spec=mlrun.api.schemas.ProjectSpec(description=project_description), ) nop_follower.create_project( - None, project, + None, + project, ) 
_assert_project_in_followers([nop_follower], project, enriched=False) _assert_no_projects_in_followers([leader_follower, second_nop_follower]) @@ -136,7 +137,8 @@ def test_projects_sync_leader_project_syncing( metadata=mlrun.api.schemas.ProjectMetadata(name=invalid_project_name), ) leader_follower.create_project( - None, invalid_project, + None, + invalid_project, ) _assert_project_in_followers([leader_follower], project, enriched=False) _assert_project_in_followers([leader_follower], invalid_project, enriched=False) @@ -147,7 +149,8 @@ def test_projects_sync_leader_project_syncing( [leader_follower, nop_follower, second_nop_follower], project ) _assert_project_not_in_followers( - [nop_follower, second_nop_follower], invalid_project_name, + [nop_follower, second_nop_follower], + invalid_project_name, ) @@ -175,13 +178,16 @@ def test_projects_sync_multiple_follower_project_adoption( ), ) nop_follower.create_project( - None, both_followers_project, + None, + both_followers_project, ) second_nop_follower.create_project( - None, both_followers_project, + None, + both_followers_project, ) second_nop_follower.create_project( - None, second_follower_project, + None, + second_follower_project, ) leader_follower.create_project = unittest.mock.Mock( wraps=leader_follower.create_project @@ -223,7 +229,8 @@ def test_create_project( ), ) projects_leader.create_project( - None, project, + None, + project, ) _assert_project_in_followers([leader_follower, nop_follower], project) @@ -272,21 +279,27 @@ def test_create_and_store_project_failure_invalid_name( ) if case["valid"]: projects_leader.create_project( - None, project, + None, + project, ) _assert_project_in_followers([leader_follower], project) projects_leader.store_project( - None, project_name, project, + None, + project_name, + project, ) _assert_project_in_followers([leader_follower], project) else: with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): projects_leader.create_project( - None, project, + None, + project, ) with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): projects_leader.store_project( - None, project_name, project, + None, + project_name, + project, ) _assert_project_not_in_followers([leader_follower], project_name) @@ -299,7 +312,8 @@ def test_ensure_project( ): project_name = "project-name" projects_leader.ensure_project( - None, project_name, + None, + project_name, ) project = mlrun.api.schemas.Project( metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), @@ -308,10 +322,12 @@ def test_ensure_project( # further calls should do nothing projects_leader.ensure_project( - None, project_name, + None, + project_name, ) projects_leader.ensure_project( - None, project_name, + None, + project_name, ) @@ -330,7 +346,9 @@ def test_store_project_creation( _assert_no_projects_in_followers([leader_follower, nop_follower]) projects_leader.store_project( - None, project_name, project, + None, + project_name, + project, ) _assert_project_in_followers([leader_follower, nop_follower], project) @@ -348,7 +366,8 @@ def test_store_project_update( spec=mlrun.api.schemas.ProjectSpec(description=project_description), ) projects_leader.create_project( - None, project, + None, + project, ) _assert_project_in_followers([leader_follower, nop_follower], project) @@ -361,7 +380,9 @@ def test_store_project_update( ) projects_leader.store_project( - None, project_name, updated_project, + None, + project_name, + updated_project, ) _assert_project_in_followers([leader_follower, nop_follower], updated_project) @@ -377,7 +398,8 @@ def 
test_patch_project( metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( - None, project, + None, + project, ) _assert_project_in_followers( [leader_follower, nop_follower], project, enriched=False @@ -412,7 +434,8 @@ def test_store_and_patch_project_failure_conflict_body_path_name( metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( - None, project, + None, + project, ) _assert_project_in_followers([leader_follower, nop_follower], project) @@ -426,7 +449,9 @@ def test_store_and_patch_project_failure_conflict_body_path_name( ) with pytest.raises(mlrun.errors.MLRunConflictError): projects_leader.patch_project( - None, project_name, {"metadata": {"name": "different-name"}}, + None, + project_name, + {"metadata": {"name": "different-name"}}, ) _assert_project_in_followers([leader_follower, nop_follower], project) @@ -442,7 +467,8 @@ def test_delete_project( metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( - None, project, + None, + project, ) _assert_project_in_followers([leader_follower, nop_follower], project) @@ -464,7 +490,8 @@ def mock_failed_delete(*args, **kwargs): metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( - None, project, + None, + project, ) _assert_project_in_followers([leader_follower, nop_follower], project) @@ -488,7 +515,8 @@ def test_list_projects( metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( - None, project, + None, + project, ) _assert_project_in_followers([leader_follower, nop_follower], project) @@ -519,13 +547,16 @@ def test_get_project( spec=mlrun.api.schemas.ProjectSpec(description=project_description), ) projects_leader.create_project( - None, project, + None, + project, ) _assert_project_in_followers([leader_follower, nop_follower], project) # change project description in follower nop_follower.patch_project( - None, project_name, {"spec": {"description": "updated description"}}, + None, + project_name, + {"spec": {"description": "updated description"}}, ) # assert get considers only the leader diff --git a/tests/api/utils/test_scheduler.py b/tests/api/utils/test_scheduler.py index 5e802a2cae..da30cc2dd0 100644 --- a/tests/api/utils/test_scheduler.py +++ b/tests/api/utils/test_scheduler.py @@ -165,7 +165,14 @@ async def test_invoke_schedule( response["data"]["metadata"]["uid"] for response in [response_1, response_2] ] db_uids = [run["metadata"]["uid"] for run in runs] - assert DeepDiff(response_uids, db_uids, ignore_order=True,) == {} + assert ( + DeepDiff( + response_uids, + db_uids, + ignore_order=True, + ) + == {} + ) schedule = scheduler.get_schedule(db, project, schedule_name, include_last_run=True) assert schedule.last_run is not None @@ -276,8 +283,8 @@ async def test_schedule_upgrade_from_scheduler_without_credentials_store( ) # stop scheduler, reconfigure to store credentials and start again (upgrade) await scheduler.stop() - mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required = unittest.mock.Mock( - return_value=True + mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required = ( + unittest.mock.Mock(return_value=True) ) await scheduler.start(db) @@ -285,8 +292,12 @@ async def test_schedule_upgrade_from_scheduler_without_credentials_store( # auth info, mock the functions for this username = "some-username" session = "some-session" - 
mlrun.api.utils.singletons.project_member.get_project_member().get_project_owner = unittest.mock.Mock( - return_value=mlrun.api.schemas.ProjectOwner(username=username, session=session) + mlrun.api.utils.singletons.project_member.get_project_member().get_project_owner = ( + unittest.mock.Mock( + return_value=mlrun.api.schemas.ProjectOwner( + username=username, session=session + ) + ) ) time_to_sleep = ( end_date - datetime.now() @@ -674,8 +685,8 @@ async def test_rescheduling_secrets_storing( scheduler: Scheduler, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): - mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required = unittest.mock.Mock( - return_value=True + mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required = ( + unittest.mock.Mock(return_value=True) ) name = "schedule-name" project = config.default_project @@ -725,8 +736,8 @@ async def test_schedule_crud_secrets_handling( scheduler: Scheduler, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): - mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required = unittest.mock.Mock( - return_value=True + mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required = ( + unittest.mock.Mock(return_value=True) ) for schedule_name in ["valid-secret-key", "invalid/secret/key"]: project = config.default_project @@ -771,7 +782,9 @@ async def test_schedule_crud_secrets_handling( # delete schedule scheduler.delete_schedule( - db, project, schedule_name, + db, + project, + schedule_name, ) _assert_schedule_secrets(scheduler, project, schedule_name, None, None) @@ -782,16 +795,16 @@ async def test_schedule_access_key_generation( scheduler: Scheduler, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): - mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required = unittest.mock.Mock( - return_value=True + mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required = ( + unittest.mock.Mock(return_value=True) ) project = config.default_project schedule_name = "schedule-name" scheduled_object = _create_mlrun_function_and_matching_scheduled_object(db, project) cron_trigger = schemas.ScheduleCronTrigger(year="1999") access_key = "generated-access-key" - mlrun.api.utils.auth.verifier.AuthVerifier().get_or_create_access_key = unittest.mock.Mock( - return_value=access_key + mlrun.api.utils.auth.verifier.AuthVerifier().get_or_create_access_key = ( + unittest.mock.Mock(return_value=access_key) ) scheduler.create_schedule( db, @@ -806,8 +819,8 @@ async def test_schedule_access_key_generation( _assert_schedule_secrets(scheduler, project, schedule_name, None, access_key) access_key = "generated-access-key-2" - mlrun.api.utils.auth.verifier.AuthVerifier().get_or_create_access_key = unittest.mock.Mock( - return_value=access_key + mlrun.api.utils.auth.verifier.AuthVerifier().get_or_create_access_key = ( + unittest.mock.Mock(return_value=access_key) ) scheduler.update_schedule( db, @@ -867,7 +880,11 @@ async def test_update_schedule( # update labels scheduler.update_schedule( - db, mlrun.api.schemas.AuthInfo(), project, schedule_name, labels=labels_2, + db, + mlrun.api.schemas.AuthInfo(), + project, + schedule_name, + labels=labels_2, ) schedule = scheduler.get_schedule(db, project, schedule_name) @@ -883,7 +900,10 @@ async def test_update_schedule( # update nothing scheduler.update_schedule( - db, mlrun.api.schemas.AuthInfo(), project, schedule_name, + db, + mlrun.api.schemas.AuthInfo(), + project, + schedule_name, ) schedule = scheduler.get_schedule(db, project, schedule_name) @@ -899,7 +919,11 @@ 
async def test_update_schedule( # update labels to empty dict scheduler.update_schedule( - db, mlrun.api.schemas.AuthInfo(), project, schedule_name, labels={}, + db, + mlrun.api.schemas.AuthInfo(), + project, + schedule_name, + labels={}, ) schedule = scheduler.get_schedule(db, project, schedule_name) @@ -920,7 +944,9 @@ async def test_update_schedule( ) # this way we're leaving ourselves one second to create the schedule preventing transient test failure cron_trigger = schemas.ScheduleCronTrigger( - second="*/1", start_date=start_date, end_date=end_date, + second="*/1", + start_date=start_date, + end_date=end_date, ) scheduler.update_schedule( db, @@ -1041,7 +1067,10 @@ def _assert_schedule_get_and_list_credentials_enrichment( expected_access_key: str, ): schedule = scheduler.get_schedule( - db, project, schedule_name, include_credentials=True, + db, + project, + schedule_name, + include_credentials=True, ) assert schedule.credentials.access_key == expected_access_key schedules = scheduler.list_schedules( @@ -1057,11 +1086,11 @@ def _assert_schedule_secrets( expected_username: str, expected_access_key: str, ): - access_key_secret_key = mlrun.api.crud.Secrets().generate_schedule_access_key_secret_key( - schedule_name + access_key_secret_key = ( + mlrun.api.crud.Secrets().generate_schedule_access_key_secret_key(schedule_name) ) - username_secret_key = mlrun.api.crud.Secrets().generate_schedule_username_secret_key( - schedule_name + username_secret_key = ( + mlrun.api.crud.Secrets().generate_schedule_username_secret_key(schedule_name) ) key_map_secret_key = mlrun.api.crud.Secrets().generate_schedule_key_map_secret_key() secret_value = mlrun.api.crud.Secrets().get_secret( diff --git a/tests/artifacts/test_dataset.py b/tests/artifacts/test_dataset.py index 00dc70ed18..1794448fbf 100644 --- a/tests/artifacts/test_dataset.py +++ b/tests/artifacts/test_dataset.py @@ -77,7 +77,9 @@ def test_dataset_upload_with_src_path_filling_hash(): src_path = pathlib.Path(tests.conftest.results) / "dataset" target_path = pathlib.Path(tests.conftest.results) / "target-dataset" artifact = mlrun.artifacts.dataset.DatasetArtifact( - df=data_frame, target_path=str(target_path), format="csv", + df=data_frame, + target_path=str(target_path), + format="csv", ) data_frame.to_csv(src_path) artifact.src_path = src_path @@ -89,6 +91,8 @@ def _generate_dataset_artifact(format_): data_frame = pandas.DataFrame({"x": [1, 2]}) target_path = pathlib.Path(tests.conftest.results) / "dataset" artifact = mlrun.artifacts.dataset.DatasetArtifact( - df=data_frame, target_path=str(target_path), format=format_, + df=data_frame, + target_path=str(target_path), + format=format_, ) return artifact diff --git a/tests/common_fixtures.py b/tests/common_fixtures.py index c185f1de6e..b546aefbac 100644 --- a/tests/common_fixtures.py +++ b/tests/common_fixtures.py @@ -165,7 +165,10 @@ def remote_builder( self, func, with_mlrun, mlrun_version_specifier=None, skip_deployed=False ): self._function = func.to_dict() - status = NuclioStatus(state="ready", nuclio_name="test-nuclio-name",) + status = NuclioStatus( + state="ready", + nuclio_name="test-nuclio-name", + ) return {"data": {"status": status.to_dict()}} def get_builder_status( diff --git a/tests/integration/sdk_api/artifacts/test_artifact_tags.py b/tests/integration/sdk_api/artifacts/test_artifact_tags.py index 9d4786b8f4..c94f2d783c 100644 --- a/tests/integration/sdk_api/artifacts/test_artifact_tags.py +++ b/tests/integration/sdk_api/artifacts/test_artifact_tags.py @@ -27,4 +27,11 @@ def 
test_list_artifact_tags(self): key, artifact.to_dict(), uid_2, tag=tag_2, project=project_name ) artifact_tags = mlrun.get_run_db().list_artifact_tags(project_name) - assert deepdiff.DeepDiff(artifact_tags, [tag, tag_2], ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + artifact_tags, + [tag, tag_2], + ignore_order=True, + ) + == {} + ) diff --git a/tests/integration/sdk_api/base.py b/tests/integration/sdk_api/base.py index e7dcac92fc..1c2c78ff44 100644 --- a/tests/integration/sdk_api/base.py +++ b/tests/integration/sdk_api/base.py @@ -100,9 +100,14 @@ def _teardown_env(self): def _run_db(self): self._logger.debug("Starting DataBase") self._run_command( - "make", args=["run-test-db"], cwd=TestMLRunIntegration.root_path, + "make", + args=["run-test-db"], + cwd=TestMLRunIntegration.root_path, + ) + output = self._run_command( + "docker", + args=["ps", "--last", "1", "-q"], ) - output = self._run_command("docker", args=["ps", "--last", "1", "-q"],) self.db_container_id = output.strip() self._logger.debug("Started DataBase", container_id=self.db_container_id) @@ -119,7 +124,10 @@ def _run_api(self): ), cwd=TestMLRunIntegration.root_path, ) - output = self._run_command("docker", args=["ps", "--last", "1", "-q"],) + output = self._run_command( + "docker", + args=["ps", "--last", "1", "-q"], + ) self.api_container_id = output.strip() # retrieve container bind port + host output = self._run_command( @@ -155,7 +163,8 @@ def _remove_db(self): "docker", args=["rm", "--force", self.db_container_id] ) self._logger.debug( - "Removed Database container", out=out, + "Removed Database container", + out=out, ) def _ensure_database_liveness(self, retry_interval=2, timeout=30): diff --git a/tests/integration/sdk_api/feature_store/test_feature_store.py b/tests/integration/sdk_api/feature_store/test_feature_store.py index d7db2cb5ac..056c3e14d9 100644 --- a/tests/integration/sdk_api/feature_store/test_feature_store.py +++ b/tests/integration/sdk_api/feature_store/test_feature_store.py @@ -19,5 +19,6 @@ def test_deploy_ingestion_service_without_preview(self): with pytest.raises(mlrun.errors.MLRunNotFoundError): fs.deploy_ingestion_service( - featureset=fset, source=v3io_source, + featureset=fset, + source=v3io_source, ) diff --git a/tests/integration/sdk_api/httpdb/test_exception_handling.py b/tests/integration/sdk_api/httpdb/test_exception_handling.py index d806801607..228b3a146d 100644 --- a/tests/integration/sdk_api/httpdb/test_exception_handling.py +++ b/tests/integration/sdk_api/httpdb/test_exception_handling.py @@ -18,7 +18,7 @@ def test_exception_handling(self): # This is practically verifies that log_and_raise puts the kwargs under the details.reason with pytest.raises( mlrun.errors.MLRunNotFoundError, - match=fr"404 Client Error: Not Found for url: http:\/\/(.*)\/{mlrun.get_run_db().get_api_path_prefix()}" + match=rf"404 Client Error: Not Found for url: http:\/\/(.*)\/{mlrun.get_run_db().get_api_path_prefix()}" r"\/files\?path=file%3A%2F%2F%2Fpath%2Fdoes%2F" r"not%2Fexist: details: {'reason': {'path': 'file:\/\/\/path\/does\/not\/exist', 'err': \"\[Errno 2] No suc" r"h file or directory: '\/path\/does\/not\/exist'\"}}", @@ -36,7 +36,7 @@ def test_exception_handling(self): ) with pytest.raises( mlrun.errors.MLRunBadRequestError, - match=fr"400 Client Error: Bad Request for url: http:\/\/(.*)\/{mlrun.get_run_db().get_api_path_prefix()}" + match=rf"400 Client Error: Bad Request for url: http:\/\/(.*)\/{mlrun.get_run_db().get_api_path_prefix()}" r"\/projects: Failed creating project some_p" 
r"roject details: {'reason': 'MLRunInvalidArgumentError\(\"Field \\'project\.metadata\.name\\' is malformed" r"\. Does not match required pattern: (.*)\"\)'}", @@ -50,7 +50,7 @@ def test_exception_handling(self): with pytest.raises( mlrun.errors.MLRunHTTPError, match=r"422 Client Error: Unprocessable Entity for url: " - fr"http:\/\/(.*)\/{mlrun.get_run_db().get_api_path_prefix()}\/projects\/some-project-name: " + rf"http:\/\/(.*)\/{mlrun.get_run_db().get_api_path_prefix()}\/projects\/some-project-name: " r"Failed deleting project some-project-name details: \[{'loc':" r" \['header', 'x-mlrun-deletion-strategy'], 'msg': \"value is not a valid enumeration member; " r"permitted: 'restrict', 'restricted', 'cascade', 'cascading', 'check'\", 'type': 'type_error.enum'," @@ -66,7 +66,7 @@ def test_exception_handling(self): with pytest.raises( mlrun.errors.MLRunInternalServerError, match=r"500 Server Error: Internal Server Error for url: http:\/\/(.*)" - fr"\/{mlrun.get_run_db().get_api_path_prefix()}\/projects\/some-project\/model-" + rf"\/{mlrun.get_run_db().get_api_path_prefix()}\/projects\/some-project\/model-" r"endpoints\?start=now-1h&end=now&top-level=False: details: {\'reason\': \"ValueError\(\'Access key must be" r" provided in Client\(\) arguments or in the V3IO_ACCESS_KEY environment variable\'\)\"}", ): @@ -79,7 +79,7 @@ def test_exception_handling(self): with pytest.raises( mlrun.errors.MLRunRuntimeError, match=r"HTTPConnectionPool\(host='does-not-exist', port=80\): Max retries exceeded with url: " - fr"\/{mlrun.get_run_db().get_api_path_prefix()}\/projects\/some-project \(Caused by NewConnectionError" + rf"\/{mlrun.get_run_db().get_api_path_prefix()}\/projects\/some-project \(Caused by NewConnectionError" r"\(': Failed to establish a new connection:" r" \[Errno (.*)'\)\): Failed retrieving project some-project", ): diff --git a/tests/platforms/test_iguazio.py b/tests/platforms/test_iguazio.py index ec22e01906..8232070f33 100644 --- a/tests/platforms/test_iguazio.py +++ b/tests/platforms/test_iguazio.py @@ -101,12 +101,18 @@ def test_mount_v3io_legacy(): } expected_volume_mount = {"mountPath": "/User", "name": "v3io", "subPath": ""} assert ( - deepdiff.DeepDiff([expected_volume], function.spec.volumes, ignore_order=True,) + deepdiff.DeepDiff( + [expected_volume], + function.spec.volumes, + ignore_order=True, + ) == {} ) assert ( deepdiff.DeepDiff( - [expected_volume_mount], function.spec.volume_mounts, ignore_order=True, + [expected_volume_mount], + function.spec.volume_mounts, + ignore_order=True, ) == {} ) diff --git a/tests/platforms/test_other.py b/tests/platforms/test_other.py index 5235eab65b..deb3853643 100644 --- a/tests/platforms/test_other.py +++ b/tests/platforms/test_other.py @@ -20,12 +20,18 @@ def test_mount_configmap(): ) assert ( - deepdiff.DeepDiff([expected_volume], function.spec.volumes, ignore_order=True,) + deepdiff.DeepDiff( + [expected_volume], + function.spec.volumes, + ignore_order=True, + ) == {} ) assert ( deepdiff.DeepDiff( - [expected_volume_mount], function.spec.volume_mounts, ignore_order=True, + [expected_volume_mount], + function.spec.volume_mounts, + ignore_order=True, ) == {} ) @@ -45,12 +51,18 @@ def test_mount_hostpath(): ) assert ( - deepdiff.DeepDiff([expected_volume], function.spec.volumes, ignore_order=True,) + deepdiff.DeepDiff( + [expected_volume], + function.spec.volumes, + ignore_order=True, + ) == {} ) assert ( deepdiff.DeepDiff( - [expected_volume_mount], function.spec.volume_mounts, ignore_order=True, + [expected_volume_mount], + 
function.spec.volume_mounts, + ignore_order=True, ) == {} ) diff --git a/tests/projects/test_project.py b/tests/projects/test_project.py index 1ea6d8d0a4..9ed6d1edaa 100644 --- a/tests/projects/test_project.py +++ b/tests/projects/test_project.py @@ -70,24 +70,44 @@ def test_create_project_from_file_with_legacy_structure(): assert project.spec.artifact_path == artifact_path # assert accessible from the project as well assert project.artifact_path == artifact_path - assert deepdiff.DeepDiff(params, project.spec.params, ignore_order=True,) == {} + assert ( + deepdiff.DeepDiff( + params, + project.spec.params, + ignore_order=True, + ) + == {} + ) # assert accessible from the project as well - assert deepdiff.DeepDiff(params, project.params, ignore_order=True,) == {} assert ( deepdiff.DeepDiff( - legacy_project.functions, project.functions, ignore_order=True, + params, + project.params, + ignore_order=True, + ) + == {} + ) + assert ( + deepdiff.DeepDiff( + legacy_project.functions, + project.functions, + ignore_order=True, ) == {} ) assert ( deepdiff.DeepDiff( - legacy_project.workflows, project.workflows, ignore_order=True, + legacy_project.workflows, + project.workflows, + ignore_order=True, ) == {} ) assert ( deepdiff.DeepDiff( - legacy_project.artifacts, project.artifacts, ignore_order=True, + legacy_project.artifacts, + project.artifacts, + ignore_order=True, ) == {} ) @@ -207,7 +227,10 @@ def test_function_run_cli(): function_path = pathlib.Path(__file__).parent / "assets" / "handler.py" project = mlrun.new_project("run-cli", str(project_dir_path)) project.set_function( - str(function_path), "my-func", image="mlrun/mlrun", handler="myhandler", + str(function_path), + "my-func", + image="mlrun/mlrun", + handler="myhandler", ) project.export() diff --git a/tests/projects/workflow.py b/tests/projects/workflow.py index 4e7471a6fd..93255eda01 100644 --- a/tests/projects/workflow.py +++ b/tests/projects/workflow.py @@ -8,8 +8,11 @@ def kfpipeline(): # analyze our dataset funcs["describe"].as_step( - name="summary", params={"label_column": "labels"}, + name="summary", + params={"label_column": "labels"}, ) # train with hyper-paremeters - funcs["trainer-function"].as_step(name="trainer-function",) + funcs["trainer-function"].as_step( + name="trainer-function", + ) diff --git a/tests/run/test_main.py b/tests/run/test_main.py index 9eb2fb3020..30a7588f87 100644 --- a/tests/run/test_main.py +++ b/tests/run/test_main.py @@ -101,7 +101,11 @@ def test_main_run_args_from_env(): '"metadata":{"uid":"123459", "name":"tst", "labels": {"kind": "job"}}}' ) - out = exec_run("'main.py -x {x}'", ["--from-env"], "test_main_run_args_from_env",) + out = exec_run( + "'main.py -x {x}'", + ["--from-env"], + "test_main_run_args_from_env", + ) db = mlrun.get_run_db() run = db.read_run("123459") print(out) diff --git a/tests/run/test_run.py b/tests/run/test_run.py index c041063d08..be882b7b0b 100644 --- a/tests/run/test_run.py +++ b/tests/run/test_run.py @@ -211,7 +211,8 @@ def test_args_integrity(): spec = tag_test(base_spec, "test_local_no_context") spec.spec.parameters = {"xyz": "789"} result = new_function( - command=f"{tests_root_directory}/no_ctx.py", args=["It's", "a", "nice", "day!"], + command=f"{tests_root_directory}/no_ctx.py", + args=["It's", "a", "nice", "day!"], ).run(spec) verify_state(result) diff --git a/tests/rundb/test_httpdb.py b/tests/rundb/test_httpdb.py index 9619b5f3f7..58e9c8cd92 100644 --- a/tests/rundb/test_httpdb.py +++ b/tests/rundb/test_httpdb.py @@ -549,7 +549,9 @@ def 
test_project_file_db_roundtrip(create_server): labels = {"key": "value"} annotations = {"annotation-key": "annotation-value"} project_metadata = mlrun.projects.project.ProjectMetadata( - project_name, labels=labels, annotations=annotations, + project_name, + labels=labels, + annotations=annotations, ) project_spec = mlrun.projects.project.ProjectSpec( description, diff --git a/tests/rundb/test_sqldb.py b/tests/rundb/test_sqldb.py index bd9618c3ce..c39117ffcf 100644 --- a/tests/rundb/test_sqldb.py +++ b/tests/rundb/test_sqldb.py @@ -183,7 +183,9 @@ def test_projects_crud(db: SQLDB, db_session: Session): project_output = db.get_project(db_session, name=project.metadata.name) assert ( deepdiff.DeepDiff( - project.dict(), project_output.dict(exclude={"id"}), ignore_order=True, + project.dict(), + project_output.dict(exclude={"id"}), + ignore_order=True, ) == {} ) diff --git a/tests/rundb/workflow.py b/tests/rundb/workflow.py index 4e7471a6fd..93255eda01 100644 --- a/tests/rundb/workflow.py +++ b/tests/rundb/workflow.py @@ -8,8 +8,11 @@ def kfpipeline(): # analyze our dataset funcs["describe"].as_step( - name="summary", params={"label_column": "labels"}, + name="summary", + params={"label_column": "labels"}, ) # train with hyper-paremeters - funcs["trainer-function"].as_step(name="trainer-function",) + funcs["trainer-function"].as_step( + name="trainer-function", + ) diff --git a/tests/runtimes/test_function.py b/tests/runtimes/test_function.py index 788dfb0eeb..9b22ef4e6c 100644 --- a/tests/runtimes/test_function.py +++ b/tests/runtimes/test_function.py @@ -57,7 +57,14 @@ def test_generate_nuclio_volumes(): ] function = mlrun.new_function(runtime=runtime) nuclio_volumes = function.spec.generate_nuclio_volumes() - assert DeepDiff(expected_nuclio_volumes, nuclio_volumes, ignore_order=True,) == {} + assert ( + DeepDiff( + expected_nuclio_volumes, + nuclio_volumes, + ignore_order=True, + ) + == {} + ) class TestAutoMountNuclio(TestAutoMount): @@ -90,7 +97,10 @@ def _execute_run(self, runtime): def test_http_trigger(): function: mlrun.runtimes.RemoteRuntime = mlrun.new_function("tst", kind="nuclio") function.with_http( - workers=2, host="x", worker_timeout=5, extra_attributes={"yy": "123"}, + workers=2, + host="x", + worker_timeout=5, + extra_attributes={"yy": "123"}, ) trigger = function.spec.config["spec.triggers.http"] diff --git a/tests/runtimes/test_run.py b/tests/runtimes/test_run.py index 5af8581c52..07133a2452 100644 --- a/tests/runtimes/test_run.py +++ b/tests/runtimes/test_run.py @@ -37,7 +37,14 @@ def test_new_function_from_runtime(): } function = mlrun.new_function(runtime=runtime) - assert DeepDiff(runtime, function.to_dict(), ignore_order=True,) == {} + assert ( + DeepDiff( + runtime, + function.to_dict(), + ignore_order=True, + ) + == {} + ) def test_new_function_args_without_command(): @@ -72,4 +79,11 @@ def test_new_function_args_without_command(): "verbose": False, } function = mlrun.new_function(runtime=runtime) - assert DeepDiff(runtime, function.to_dict(), ignore_order=True,) == {} + assert ( + DeepDiff( + runtime, + function.to_dict(), + ignore_order=True, + ) + == {} + ) diff --git a/tests/system/backward_compatiblity/test_api_backward_compatibility.py b/tests/system/backward_compatiblity/test_api_backward_compatibility.py index 2cd62eb1e2..15edc2dfb2 100644 --- a/tests/system/backward_compatiblity/test_api_backward_compatibility.py +++ b/tests/system/backward_compatiblity/test_api_backward_compatibility.py @@ -42,5 +42,6 @@ def 
test_endpoints_called_by_sdk_from_inside_jobs(self): with pytest.raises(mlrun.runtimes.utils.RunError): function.run( - name=f"test_{failure_handler}", handler=failure_handler, + name=f"test_{failure_handler}", + handler=failure_handler, ) diff --git a/tests/system/demos/horovod/test_horovod.py b/tests/system/demos/horovod/test_horovod.py index 59d243dd59..ba30da5380 100644 --- a/tests/system/demos/horovod/test_horovod.py +++ b/tests/system/demos/horovod/test_horovod.py @@ -34,7 +34,10 @@ def create_demo_project(self) -> mlrun.projects.MlrunProject: self._logger.debug("Creating iris-generator function") function_path = str(self.assets_path / "utils_functions.py") utils = mlrun.code_to_function( - name="utils", kind="job", filename=function_path, image="mlrun/mlrun", + name="utils", + kind="job", + filename=function_path, + image="mlrun/mlrun", ) utils.spec.remote = True diff --git a/tests/system/demos/sklearn/test_sklearn.py b/tests/system/demos/sklearn/test_sklearn.py index ef3c77ac79..8d1eb36f6f 100644 --- a/tests/system/demos/sklearn/test_sklearn.py +++ b/tests/system/demos/sklearn/test_sklearn.py @@ -21,7 +21,10 @@ def create_demo_project(self) -> mlrun.projects.MlrunProject: self._logger.debug("Creating iris-generator function") function_path = str(self.assets_path / "iris_generator_function.py") iris_generator_function = mlrun.code_to_function( - name="gen-iris", kind="job", filename=function_path, image="mlrun/mlrun", + name="gen-iris", + kind="job", + filename=function_path, + image="mlrun/mlrun", ) iris_generator_function.spec.remote = True diff --git a/tests/system/examples/dask/test_dask.py b/tests/system/examples/dask/test_dask.py index a948fb321f..9fc8b92ddd 100644 --- a/tests/system/examples/dask/test_dask.py +++ b/tests/system/examples/dask/test_dask.py @@ -21,7 +21,9 @@ class TestDask(TestMLRunSystem): def custom_setup(self): self._logger.debug("Creating dask function") self.dask_function = code_to_function( - "mydask", kind="dask", filename=str(self.assets_path / "dask_function.py"), + "mydask", + kind="dask", + filename=str(self.assets_path / "dask_function.py"), ).apply(mount_v3io()) self.dask_function.spec.image = "mlrun/ml-models" diff --git a/tests/system/feature_store/test_feature_store.py b/tests/system/feature_store/test_feature_store.py index 4050650e5c..6a9f276a53 100644 --- a/tests/system/feature_store/test_feature_store.py +++ b/tests/system/feature_store/test_feature_store.py @@ -340,7 +340,8 @@ def test_feature_set_db(self): name = "stocks_test" stocks_set = fs.FeatureSet(name, entities=["ticker"]) fs.preview( - stocks_set, stocks, + stocks_set, + stocks, ) stocks_set.save() db = mlrun.get_run_db() @@ -534,7 +535,11 @@ def test_filtering_parquet_by_time(self): end_time="2020-12-01 17:33:16", ) - resp = fs.ingest(measurements, source, return_df=True,) + resp = fs.ingest( + measurements, + source, + return_df=True, + ) assert len(resp) == 10 # start time > timestamp in source @@ -546,7 +551,11 @@ def test_filtering_parquet_by_time(self): end_time="2022-12-01 17:33:16", ) - resp = fs.ingest(measurements, source, return_df=True,) + resp = fs.ingest( + measurements, + source, + return_df=True, + ) assert len(resp) == 0 @pytest.mark.parametrize("key_bucketing_number", [None, 0, 4]) @@ -594,7 +603,10 @@ def test_ingest_partitioned_by_key_and_time( kind = TargetTypes.parquet path = f"{get_default_prefix_for_target(kind)}/sets/{name}-latest" path = path.format(name=name, kind=kind, project=self.project_name) - dataset = pq.ParquetDataset(path, filesystem=file_system,) 
+ dataset = pq.ParquetDataset( + path, + filesystem=file_system, + ) partitions = [key for key, _ in dataset.pieces[0].partition_keys] if key_bucketing_number is None: @@ -647,7 +659,8 @@ def test_ingest_twice_with_nulls(self): df.set_index("my_string") source = DataFrameSource(df) measurements.set_targets( - targets=[ParquetTarget(partitioned=True)], with_defaults=False, + targets=[ParquetTarget(partitioned=True)], + with_defaults=False, ) resp1 = fs.ingest(measurements, source) assert resp1.to_dict() == { @@ -676,7 +689,8 @@ def test_ingest_twice_with_nulls(self): df.set_index("my_string") source = DataFrameSource(df) measurements.set_targets( - targets=[ParquetTarget(partitioned=True)], with_defaults=False, + targets=[ParquetTarget(partitioned=True)], + with_defaults=False, ) resp1 = fs.ingest(measurements, source, overwrite=False) assert resp1.to_dict() == { @@ -800,7 +814,10 @@ def test_multiple_entities(self): ) data_set.add_aggregation( - column="bid", operations=["sum", "max"], windows="1h", period="10m", + column="bid", + operations=["sum", "max"], + windows="1h", + period="10m", ) fs.preview( data_set, @@ -1049,7 +1066,9 @@ def test_schedule_on_filtered_by_time(self, partitioned): ) feature_set = fs.FeatureSet( - name=name, entities=[fs.Entity("first_name")], timestamp_key="time", + name=name, + entities=[fs.Entity("first_name")], + timestamp_key="time", ) if partitioned: @@ -1149,7 +1168,9 @@ def test_overwrite_single_file(self): source = ParquetSource("myparquet", schedule=cron_trigger, path=path) feature_set = fs.FeatureSet( - name="overwrite", entities=[fs.Entity("first_name")], timestamp_key="time", + name="overwrite", + entities=[fs.Entity("first_name")], + timestamp_key="time", ) targets = [ParquetTarget(path="v3io:///bigdata/bla.parquet")] @@ -1196,7 +1217,10 @@ def test_query_on_fixed_window(self, fixed_window_type): ) data_set.add_aggregation( - name="bids", column="bid", operations=["sum", "max"], windows="24h", + name="bids", + column="bid", + operations=["sum", "max"], + windows="24h", ) fs.ingest(data_set, data, return_df=True) @@ -1292,7 +1316,10 @@ def test_feature_aliases(self): data_set = fs.FeatureSet("aliass", entities=[Entity("ticker")]) data_set.add_aggregation( - column="price", operations=["sum", "max"], windows="1h", period="10m", + column="price", + operations=["sum", "max"], + windows="1h", + period="10m", ) fs.ingest(data_set, df) @@ -1647,7 +1674,11 @@ def test_purge(self): key = "patient_id" fset = fs.FeatureSet("purge", entities=[Entity(key)], timestamp_key="timestamp") path = os.path.relpath(str(self.assets_path / "testdata.csv")) - source = CSVSource("mycsv", path=path, time_field="timestamp",) + source = CSVSource( + "mycsv", + path=path, + time_field="timestamp", + ) targets = [ CSVTarget(), CSVTarget(name="specified-path", path="v3io:///bigdata/csv-purge-test.csv"), @@ -1655,7 +1686,8 @@ def test_purge(self): NoSqlTarget(), ] fset.set_targets( - targets=targets, with_defaults=False, + targets=targets, + with_defaults=False, ) fs.ingest(fset, source) @@ -1687,7 +1719,11 @@ def get_v3io_api_host(): name="nosqlpurge", entities=[Entity(key)], timestamp_key="timestamp" ) path = os.path.relpath(str(self.assets_path / "testdata.csv")) - source = CSVSource("mycsv", path=path, time_field="timestamp",) + source = CSVSource( + "mycsv", + path=path, + time_field="timestamp", + ) targets = [ NoSqlTarget( name="nosql", path="v3io:///bigdata/system-test-project/nosql-purge" @@ -1701,7 +1737,8 @@ def get_v3io_api_host(): for tar in targets: test_target = [tar] 
fset.set_targets( - with_defaults=False, targets=test_target, + with_defaults=False, + targets=test_target, ) fs.ingest(fset, source) verify_purge(fset, test_target) @@ -1904,7 +1941,9 @@ def validate_result(test_vector, test_keys): # change feature set and save with tag test_set.add_aggregation( - "bid", ["avg"], "1h", + "bid", + ["avg"], + "1h", ) new_column = "bid_avg_1h" test_set.metadata.tag = tag @@ -1960,7 +1999,9 @@ def validate_result(test_vector, test_keys): # change feature set and save with tag test_set.add_aggregation( - "bid", ["avg"], "1h", + "bid", + ["avg"], + "1h", ) new_column = "bid_avg_1h" test_set.metadata.tag = tag @@ -2096,7 +2137,9 @@ def test_online_impute(self): "imp1", entities=[Entity("name")], timestamp_key="time_stamp" ) data_set1.add_aggregation( - "data", ["avg", "max"], "1h", + "data", + ["avg", "max"], + "1h", ) fs.ingest(data_set1, data, infer_options=fs.InferOptions.default()) diff --git a/tests/system/feature_store/test_spark_engine.py b/tests/system/feature_store/test_spark_engine.py index c24dca41a3..fe5bad9828 100644 --- a/tests/system/feature_store/test_spark_engine.py +++ b/tests/system/feature_store/test_spark_engine.py @@ -100,7 +100,9 @@ def test_error_flow(self): ) measurements = fs.FeatureSet( - "measurements", entities=[fs.Entity("name")], engine="spark", + "measurements", + entities=[fs.Entity("name")], + engine="spark", ) with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): @@ -245,11 +247,15 @@ def test_aggregations(self): source = ParquetSource("myparquet", path=path, time_field="time") data_set = fs.FeatureSet( - f"{name}_storey", entities=[Entity("first_name"), Entity("last_name")], + f"{name}_storey", + entities=[Entity("first_name"), Entity("last_name")], ) data_set.add_aggregation( - column="bid", operations=["sum", "max"], windows="1h", period="10m", + column="bid", + operations=["sum", "max"], + windows="1h", + period="10m", ) df = fs.ingest(data_set, source, targets=[]) @@ -273,7 +279,10 @@ def test_aggregations(self): ) data_set.add_aggregation( - column="bid", operations=["sum", "max"], windows="1h", period="10m", + column="bid", + operations=["sum", "max"], + windows="1h", + period="10m", ) fs.ingest( diff --git a/tests/system/model_monitoring/test_model_monitoring.py b/tests/system/model_monitoring/test_model_monitoring.py index 44fabe042b..4acb48c9c3 100644 --- a/tests/system/model_monitoring/test_model_monitoring.py +++ b/tests/system/model_monitoring/test_model_monitoring.py @@ -241,7 +241,10 @@ def test_model_monitoring_voting_ensemble(self): label_column = "label" - train_set = pd.DataFrame(iris["data"], columns=columns,) + train_set = pd.DataFrame( + iris["data"], + columns=columns, + ) train_set[label_column] = iris["target"] diff --git a/tests/system/projects/test_project.py b/tests/system/projects/test_project.py index cb7f7951a3..74f0078e34 100644 --- a/tests/system/projects/test_project.py +++ b/tests/system/projects/test_project.py @@ -268,7 +268,10 @@ def test_local_cli(self): name = "lclclipipe" project = self._create_project(name) project.set_function( - "gen_iris.py", "gen-iris", image="mlrun/mlrun", handler="iris_generator", + "gen_iris.py", + "gen-iris", + image="mlrun/mlrun", + handler="iris_generator", ) project.save() print(project.to_yaml()) diff --git a/tests/system/runtimes/test_nuclio.py b/tests/system/runtimes/test_nuclio.py index 078a7be973..66e9e8b5e5 100644 --- a/tests/system/runtimes/test_nuclio.py +++ b/tests/system/runtimes/test_nuclio.py @@ -163,7 +163,9 @@ def test_hyper_run(self): fn 
= self._deploy_function(2) hyper_param_options = mlrun.model.HyperParamOptions( - parallel_runs=4, selector="max.accuracy", max_errors=1, + parallel_runs=4, + selector="max.accuracy", + max_errors=1, ) p1 = [4, 2, 5, 8, 9, 6, 1, 11, 1, 1, 2, 1, 1] diff --git a/tests/test_builder.py b/tests/test_builder.py index 7736059d1c..3fecc32467 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -15,7 +15,10 @@ def test_build_runtime_use_base_image_when_no_build(): base_image = "mlrun/ml-models" fn.build_config(base_image=base_image) assert fn.spec.image == "" - ready = mlrun.builder.build_runtime(mlrun.api.schemas.AuthInfo(), fn,) + ready = mlrun.builder.build_runtime( + mlrun.api.schemas.AuthInfo(), + fn, + ) assert ready is True assert fn.spec.image == base_image @@ -27,7 +30,9 @@ def test_build_runtime_use_image_when_no_build(): ) assert fn.spec.image == image ready = mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), fn, with_mlrun=False, + mlrun.api.schemas.AuthInfo(), + fn, + with_mlrun=False, ) assert ready is True assert fn.spec.image == image @@ -94,7 +99,8 @@ def test_build_runtime_insecure_registries(monkeypatch): mlrun.mlconf.httpdb.builder.insecure_push_registry_mode = case["push_mode"] mlrun.mlconf.httpdb.builder.docker_registry_secret = case["secret"] mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), function, + mlrun.api.schemas.AuthInfo(), + function, ) assert ( insecure_flags.issubset( @@ -130,12 +136,15 @@ def test_build_runtime_target_image(monkeypatch): kind="job", requirements=["some-package"], ) - image_name_prefix = mlrun.mlconf.httpdb.builder.function_target_image_name_prefix_template.format( - project=function.metadata.project, name=function.metadata.name + image_name_prefix = ( + mlrun.mlconf.httpdb.builder.function_target_image_name_prefix_template.format( + project=function.metadata.project, name=function.metadata.name + ) ) mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), function, + mlrun.api.schemas.AuthInfo(), + function, ) # assert the default target image @@ -147,7 +156,8 @@ def test_build_runtime_target_image(monkeypatch): f"{registry}/{image_name_prefix}-some-addition:{function.metadata.tag}" ) mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), function, + mlrun.api.schemas.AuthInfo(), + function, ) target_image = _get_target_image_from_create_pod_mock() assert target_image == function.spec.build.image @@ -159,7 +169,8 @@ def test_build_runtime_target_image(monkeypatch): f"/{image_name_prefix}-some-addition:{function.metadata.tag}" ) mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), function, + mlrun.api.schemas.AuthInfo(), + function, ) target_image = _get_target_image_from_create_pod_mock() assert ( @@ -175,7 +186,8 @@ def test_build_runtime_target_image(monkeypatch): function.spec.build.image = invalid_image with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), function, + mlrun.api.schemas.AuthInfo(), + function, ) # assert if we can not-stick to the regex if it's a different registry @@ -184,7 +196,8 @@ def test_build_runtime_target_image(monkeypatch): f":{function.metadata.tag}" ) mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), function, + mlrun.api.schemas.AuthInfo(), + function, ) target_image = _get_target_image_from_create_pod_mock() assert target_image == function.spec.build.image diff --git a/tests/test_config.py b/tests/test_config.py index 27a33f7bdc..df865215b8 100644 --- a/tests/test_config.py +++ 
b/tests/test_config.py @@ -216,7 +216,8 @@ def test_setting_dbpath_trigger_connect(requests_mock: requests_mock_package.Moc "remote_host": remote_host, } requests_mock.get( - f"{api_url}/{HTTPRunDB.get_api_path_prefix()}/client-spec", json=response_body, + f"{api_url}/{HTTPRunDB.get_api_path_prefix()}/client-spec", + json=response_body, ) assert "" == mlconf.config.remote_host mlconf.config.dbpath = api_url diff --git a/tests/test_kfp.py b/tests/test_kfp.py index 6665f24238..8ee55dcd25 100644 --- a/tests/test_kfp.py +++ b/tests/test_kfp.py @@ -189,5 +189,7 @@ def _assert_metrics_file( def _generate_task(p1, out_path): return new_task( - params={"p1": p1}, out_path=out_path, outputs=["accuracy", "loss"], + params={"p1": p1}, + out_path=out_path, + outputs=["accuracy", "loss"], ).set_label("tests", "kfp") diff --git a/tests/test_requirements.py b/tests/test_requirements.py index a65622aea8..21d3edea78 100644 --- a/tests/test_requirements.py +++ b/tests/test_requirements.py @@ -83,8 +83,6 @@ def test_requirement_specifiers_convention(): "aiobotocore": {"~=1.4.0"}, "storey": {"~=0.8.11, <0.8.12"}, "bokeh": {"~=2.4, >=2.4.2"}, - # Black is not stable yet and does not have a release that is not beta, so can't be used with ~= - "black": {"<=19.10b0"}, # These 2 are used in a tests that is purposed to test requirement without specifiers "faker": {""}, "python-dotenv": {""}, @@ -105,7 +103,6 @@ def test_requirement_specifiers_convention(): "urllib3": {">=1.25.4, <1.27"}, "cryptography": {"~=3.0, <3.4"}, "chardet": {">=3.0.2, <4.0"}, - "google-auth": {">=1.25.0, <2.0dev"}, "numpy": {">=1.16.5, <1.22.0"}, "alembic": {"~=1.4,<1.6.0"}, "boto3": {"~=1.9, <1.17.107"}, @@ -114,7 +111,10 @@ def test_requirement_specifiers_convention(): "pyarrow": {">=1,<6"}, } - for (ignored_requirement_name, ignored_specifiers,) in ignored_invalid_map.items(): + for ( + ignored_requirement_name, + ignored_specifiers, + ) in ignored_invalid_map.items(): if ignored_requirement_name in invalid_requirement_specifiers_map: diff = deepdiff.DeepDiff( invalid_requirement_specifiers_map[ignored_requirement_name], diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 867c59546a..f508533959 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -509,7 +509,7 @@ def test_create_exponential_backoff(): max_value = 120 backoff = mlrun.utils.helpers.create_exponential_backoff(base, max_value) for i in range(1, 120): - expected_value = min(base ** i, max_value) + expected_value = min(base**i, max_value) assert expected_value, next(backoff)