From a86996bb391da0f7f9ce7c694a55a03874b9e8e5 Mon Sep 17 00:00:00 2001 From: Alon Maor <48641682+AlonMaor14@users.noreply.github.com> Date: Tue, 19 Jul 2022 09:52:09 +0300 Subject: [PATCH] [Runtimes] - Security context support (#2124) --- mlrun/api/schemas/function.py | 5 + mlrun/config.py | 10 ++ mlrun/runtimes/daskjob.py | 2 + mlrun/runtimes/function.py | 10 ++ mlrun/runtimes/mpijob/abstract.py | 2 + mlrun/runtimes/mpijob/v1.py | 10 ++ mlrun/runtimes/mpijob/v1alpha1.py | 5 + mlrun/runtimes/pod.py | 157 +++++++++++---------------- mlrun/runtimes/remotesparkjob.py | 2 + mlrun/runtimes/serving.py | 2 + mlrun/runtimes/sparkjob/abstract.py | 2 + mlrun/runtimes/sparkjob/spark3job.py | 47 ++++---- mlrun/runtimes/utils.py | 1 + tests/api/runtimes/base.py | 16 +++ tests/api/runtimes/test_dask.py | 44 ++++++++ tests/api/runtimes/test_kubejob.py | 40 +++++++ tests/api/runtimes/test_nuclio.py | 57 ++++++++++ tests/api/runtimes/test_spark.py | 24 ++++ tests/runtimes/test_pod.py | 5 +- tests/runtimes/test_run.py | 1 + 20 files changed, 326 insertions(+), 116 deletions(-) diff --git a/mlrun/api/schemas/function.py b/mlrun/api/schemas/function.py index 264fa3de60e..e4150ab17b2 100644 --- a/mlrun/api/schemas/function.py +++ b/mlrun/api/schemas/function.py @@ -46,8 +46,13 @@ class ImagePullSecret(pydantic.BaseModel): default: typing.Optional[str] +class SecurityContext(pydantic.BaseModel): + default: typing.Optional[str] + + class FunctionSpec(pydantic.BaseModel): image_pull_secret: typing.Optional[ImagePullSecret] + security_context: typing.Optional[SecurityContext] class Function(pydantic.BaseModel): diff --git a/mlrun/config.py b/mlrun/config.py index fa5930f119f..fb9d5797bea 100644 --- a/mlrun/config.py +++ b/mlrun/config.py @@ -132,6 +132,11 @@ "function": { "spec": { "image_pull_secret": {"default": None}, + "security_context": { + # default security context to be applied to all functions - json string base64 encoded format + # in camelCase format: {"runAsUser": 1000, "runAsGroup": 3000} + "default": "e30=", # encoded empty dict + }, }, }, "function_defaults": { @@ -512,6 +517,11 @@ def get_preemptible_tolerations(self) -> list: "preemptible_nodes.tolerations", list ) + def get_default_function_security_context(self) -> dict: + return self.decode_base64_config_and_load_to_object( + "function.spec.security_context.default", dict + ) + def is_preemption_nodes_configured(self): if ( not self.get_preemptible_tolerations() diff --git a/mlrun/runtimes/daskjob.py b/mlrun/runtimes/daskjob.py index 16fa6ad2ce1..05c0092dbb0 100644 --- a/mlrun/runtimes/daskjob.py +++ b/mlrun/runtimes/daskjob.py @@ -103,6 +103,7 @@ def __init__( workdir=None, tolerations=None, preemption_mode=None, + security_context=None, ): super().__init__( @@ -131,6 +132,7 @@ def __init__( workdir=workdir, tolerations=tolerations, preemption_mode=preemption_mode, + security_context=security_context, ) self.args = args diff --git a/mlrun/runtimes/function.py b/mlrun/runtimes/function.py index 66a556f265d..112a42b7d8c 100644 --- a/mlrun/runtimes/function.py +++ b/mlrun/runtimes/function.py @@ -165,6 +165,7 @@ def __init__( image_pull_secret=None, tolerations=None, preemption_mode=None, + security_context=None, ): super().__init__( @@ -193,6 +194,7 @@ def __init__( image_pull_secret=image_pull_secret, tolerations=tolerations, preemption_mode=preemption_mode, + security_context=security_context, ) self.base_spec = base_spec or {} @@ -1217,6 +1219,14 @@ def compile_function_config( if function.spec.service_account: 
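# Illustrative sketch (not part of this patch): configuring the new default
# function security context. The config value is a base64-encoded JSON string
# with camelCase keys, and get_default_function_security_context() decodes it
# back into a dict (the shipped default "e30=" is an empty dict). Assumes an
# mlrun installation that includes this change.
import base64
import json

import mlrun

default_security_context = {"runAsUser": 1000, "runAsGroup": 3000}
mlrun.mlconf.function.spec.security_context.default = base64.b64encode(
    json.dumps(default_security_context).encode("utf-8")
)
# should print {'runAsUser': 1000, 'runAsGroup': 3000}
print(mlrun.mlconf.get_default_function_security_context())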
nuclio_spec.set_config("spec.serviceAccount", function.spec.service_account) + if function.spec.security_context: + nuclio_spec.set_config( + "spec.securityContext", + mlrun.runtimes.pod.get_sanitized_attribute( + function.spec, "security_context" + ), + ) + if ( function.spec.base_spec or function.spec.build.functionSourceCode diff --git a/mlrun/runtimes/mpijob/abstract.py b/mlrun/runtimes/mpijob/abstract.py index e46dbe494d8..33a8381cdb3 100644 --- a/mlrun/runtimes/mpijob/abstract.py +++ b/mlrun/runtimes/mpijob/abstract.py @@ -58,6 +58,7 @@ def __init__( pythonpath=None, tolerations=None, preemption_mode=None, + security_context=None, ): super().__init__( command=command, @@ -85,6 +86,7 @@ def __init__( pythonpath=pythonpath, tolerations=tolerations, preemption_mode=preemption_mode, + security_context=security_context, ) self.mpi_args = mpi_args or [ "-x", diff --git a/mlrun/runtimes/mpijob/v1.py b/mlrun/runtimes/mpijob/v1.py index 8a89b1da52b..40c186f05af 100644 --- a/mlrun/runtimes/mpijob/v1.py +++ b/mlrun/runtimes/mpijob/v1.py @@ -61,6 +61,7 @@ def __init__( pythonpath=None, tolerations=None, preemption_mode=None, + security_context=None, ): super().__init__( command=command, @@ -89,6 +90,7 @@ def __init__( pythonpath=pythonpath, tolerations=tolerations, preemption_mode=preemption_mode, + security_context=security_context, ) self.clean_pod_policy = clean_pod_policy or MPIJobV1CleanPodPolicies.default() @@ -213,6 +215,14 @@ def _generate_mpi_job( "spec.imagePullSecrets", [{"name": self.spec.image_pull_secret}], ) + if self.spec.security_context: + update_in( + pod_template, + "spec.securityContext", + mlrun.runtimes.pod.get_sanitized_attribute( + self.spec, "security_context" + ), + ) update_in(pod_template, "metadata.labels", pod_labels) update_in(pod_template, "spec.volumes", self.spec.volumes) update_in(pod_template, "spec.nodeName", self.spec.node_name) diff --git a/mlrun/runtimes/mpijob/v1alpha1.py b/mlrun/runtimes/mpijob/v1alpha1.py index 12dfdf59c6e..1db61ac5e54 100644 --- a/mlrun/runtimes/mpijob/v1alpha1.py +++ b/mlrun/runtimes/mpijob/v1alpha1.py @@ -95,6 +95,11 @@ def _generate_mpi_job( "spec.template.spec.tolerations", mlrun.runtimes.pod.get_sanitized_attribute(self.spec, "tolerations"), ) + update_in( + job, + "spec.template.spec.securityContext", + mlrun.runtimes.pod.get_sanitized_attribute(self.spec, "security_context"), + ) if self.spec.priority_class_name and len( mlconf.get_valid_function_priority_class_names() ): diff --git a/mlrun/runtimes/pod.py b/mlrun/runtimes/pod.py index a5736836c7f..379d3d18056 100644 --- a/mlrun/runtimes/pod.py +++ b/mlrun/runtimes/pod.py @@ -33,7 +33,7 @@ generate_preemptible_tolerations, ) from ..secrets import SecretsStore -from ..utils import get_in, logger, normalize_name, update_in +from ..utils import logger, normalize_name, update_in from .base import BaseRuntime, FunctionSpec, spec_fields from .utils import ( apply_kfp, @@ -51,28 +51,34 @@ "attribute_type": k8s_client.V1Affinity, "sub_attribute_type": None, "contains_many": False, - "not_sanitized": "node_affinity", "not_sanitized_class": dict, - "sanitized": "nodeAffinity", }, "tolerations": { "attribute_type_name": "List[V1.Toleration]", "attribute_type": list, "contains_many": True, "sub_attribute_type": k8s_client.V1Toleration, - "not_sanitized": "toleration_seconds", "not_sanitized_class": list, - "sanitized": "tolerationSeconds", + }, + "security_context": { + "attribute_type_name": "V1SecurityContext", + "attribute_type": k8s_client.V1SecurityContext, + "sub_attribute_type": 
None, + "contains_many": False, + "not_sanitized_class": dict, }, } sanitized_attributes = { "affinity": sanitized_types["affinity"], "tolerations": sanitized_types["tolerations"], + "security_context": sanitized_types["security_context"], "executor_tolerations": sanitized_types["tolerations"], "driver_tolerations": sanitized_types["tolerations"], "executor_affinity": sanitized_types["affinity"], "driver_affinity": sanitized_types["affinity"], + "executor_security_context": sanitized_types["security_context"], + "driver_security_context": sanitized_types["security_context"], } @@ -92,6 +98,7 @@ class KubeResourceSpec(FunctionSpec): "priority_class_name", "tolerations", "preemption_mode", + "security_context", ] def __init__( @@ -121,6 +128,7 @@ def __init__( priority_class_name=None, tolerations=None, preemption_mode=None, + security_context=None, ): super().__init__( command=command, @@ -160,6 +168,9 @@ def __init__( ) self._tolerations = tolerations self.preemption_mode = preemption_mode + self.security_context = ( + security_context or mlrun.mlconf.get_default_function_security_context() + ) @property def volumes(self) -> list: @@ -220,11 +231,24 @@ def preemption_mode(self, mode): self._preemption_mode = mode or mlconf.function_defaults.preemption_mode self.enrich_function_preemption_spec() + @property + def security_context(self) -> k8s_client.V1SecurityContext: + return self._security_context + + @security_context.setter + def security_context(self, security_context): + self._security_context = transform_attribute_to_k8s_class_instance( + "security_context", security_context + ) + def to_dict(self, fields=None, exclude=None): - struct = super().to_dict(fields, exclude=["affinity", "tolerations"]) + exclude = exclude or [] + _exclude = ["affinity", "tolerations", "security_context"] + struct = super().to_dict(fields, exclude=list(set(exclude + _exclude))) api = k8s_client.ApiClient() - struct["affinity"] = api.sanitize_for_serialization(self.affinity) - struct["tolerations"] = api.sanitize_for_serialization(self.tolerations) + for field in _exclude: + if field not in exclude: + struct[field] = api.sanitize_for_serialization(getattr(self, field)) return struct def update_vols_and_mounts( @@ -241,88 +265,6 @@ def update_vols_and_mounts( def _get_affinity_as_k8s_class_instance(self): pass - def _transform_attribute_to_k8s_class_instance( - self, attribute_name, attribute, is_sub_attr: bool = False - ): - if attribute_name not in sanitized_attributes: - raise mlrun.errors.MLRunInvalidArgumentError( - f"{attribute_name} isn't in the available sanitized attributes" - ) - attribute_config = sanitized_attributes[attribute_name] - # initialize empty attribute type - if attribute is None: - return None - if isinstance(attribute, dict): - if self._resolve_if_type_sanitized(attribute_name, attribute): - api = k8s_client.ApiClient() - # not ideal to use their private method, but looks like that's the only option - # Taken from https://github.com/kubernetes-client/python/issues/977 - attribute_type = attribute_config["attribute_type"] - if attribute_config["contains_many"]: - attribute_type = attribute_config["sub_attribute_type"] - attribute = api._ApiClient__deserialize(attribute, attribute_type) - - elif isinstance(attribute, list): - attribute_instance = [] - for sub_attr in attribute: - if not isinstance(sub_attr, dict): - return attribute - attribute_instance.append( - self._transform_attribute_to_k8s_class_instance( - attribute_name, sub_attr, is_sub_attr=True - ) - ) - attribute = 
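# Illustrative sketch (not part of this patch): the new security_context
# property accepts either a kubernetes client object or its sanitized
# (camelCase) dict form and always stores a V1SecurityContext instance.
# Assumes a job runtime created via mlrun.new_function.
from kubernetes import client as k8s_client

import mlrun

function = mlrun.new_function("trainer", kind="job", image="mlrun/mlrun")
function.spec.security_context = k8s_client.V1SecurityContext(run_as_user=1000)
function.spec.security_context = {"runAsUser": 1000}  # equivalent sanitized dict
assert isinstance(function.spec.security_context, k8s_client.V1SecurityContext)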
attribute_instance - # if user have set one attribute but its part of an attribute that contains many then return inside a list - if ( - not is_sub_attr - and attribute_config["contains_many"] - and isinstance(attribute, attribute_config["sub_attribute_type"]) - ): - # initialize attribute instance and add attribute to it, - # mainly done when attribute is a list but user defines only sets the attribute not in the list - attribute_instance = attribute_config["attribute_type"]() - attribute_instance.append(attribute) - return attribute_instance - return attribute - - def _get_sanitized_attribute(self, attribute_name: str): - """ - When using methods like to_dict() on kubernetes class instances we're getting the attributes in snake_case - Which is ok if we're using the kubernetes python package but not if for example we're creating CRDs that we - apply directly. For that we need the sanitized (CamelCase) version. - """ - attribute = getattr(self, attribute_name) - if attribute_name not in sanitized_attributes: - raise mlrun.errors.MLRunInvalidArgumentError( - f"{attribute_name} isn't in the available sanitized attributes" - ) - attribute_config = sanitized_attributes[attribute_name] - if not attribute: - return attribute_config["not_sanitized_class"]() - - # check if attribute of type dict, and then check if type is sanitized - if isinstance(attribute, dict): - if attribute_config["not_sanitized_class"] != dict: - raise mlrun.errors.MLRunInvalidArgumentTypeError( - f"expected to to be of type {attribute_config.get('not_sanitized_class')} but got dict" - ) - if _resolve_if_type_sanitized(attribute_name, attribute): - return attribute - - elif isinstance(attribute, list) and not isinstance( - attribute[0], attribute_config["sub_attribute_type"] - ): - if attribute_config["not_sanitized_class"] != list: - raise mlrun.errors.MLRunInvalidArgumentTypeError( - f"expected to to be of type {attribute_config.get('not_sanitized_class')} but got list" - ) - if _resolve_if_type_sanitized(attribute_name, attribute[0]): - return attribute - - api = k8s_client.ApiClient() - return api.sanitize_for_serialization(attribute) - def _set_volume_mount( self, volume_mount, volume_mounts_field_name="_volume_mounts" ): @@ -1062,6 +1004,26 @@ def with_preemption_mode(self, mode: typing.Union[PreemptionModes, str]): preemptible_mode = PreemptionModes(mode) self.spec.preemption_mode = preemptible_mode.value + def with_security_context(self, security_context: k8s_client.V1SecurityContext): + """ + Set security context for the pod + Example: + + from kubernetes import client as k8s_client + + security_context = k8s_client.V1SecurityContext( + run_as_user=1000, + run_as_group=3000, + ) + function.with_security_context(security_context) + + More info: + https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#set-the-security-context-for-a-pod + + :param security_context: The security context for the pod + """ + self.spec.security_context = security_context + def list_valid_priority_class_names(self): return mlconf.get_valid_function_priority_class_names() @@ -1268,6 +1230,7 @@ def kube_resource_spec_to_pod_spec( if len(mlconf.get_valid_function_priority_class_names()) else None, tolerations=kube_resource_spec.tolerations, + security_context=kube_resource_spec.security_context, ) @@ -1275,13 +1238,15 @@ def _resolve_if_type_sanitized(attribute_name, attribute): attribute_config = sanitized_attributes[attribute_name] # heuristic - if one of the keys contains _ as part of the dict it means to_dict on the 
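# Illustrative sketch (not part of this patch): precedence between the
# configured default and an explicit per-function security context, mirroring
# the tests further below. The spec falls back to
# mlconf.get_default_function_security_context() only when nothing was set,
# while with_security_context() always overrides it.
from kubernetes import client as k8s_client

import mlrun

function = mlrun.new_function("trainer", kind="job", image="mlrun/mlrun")
# populated from the configured default (an empty dict unless configured)
print(function.spec.security_context)

function.with_security_context(
    k8s_client.V1SecurityContext(run_as_user=2000, run_as_group=2000)
)
print(function.spec.security_context.run_as_user)  # 2000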
kubernetes # object performed, there's nothing we can do at that point to transform it to the sanitized version - if get_in(attribute, attribute_config["not_sanitized"]): - raise mlrun.errors.MLRunInvalidArgumentTypeError( - f"{attribute_name} must be instance of kubernetes {attribute_config.get('attribute_type_name')} class" - ) + for key in attribute.keys(): + if "_" in key: + raise mlrun.errors.MLRunInvalidArgumentTypeError( + f"{attribute_name} must be instance of kubernetes {attribute_config.get('attribute_type_name')} class " + f"but contains not sanitized key: {key}" + ) + # then it's already the sanitized version - elif get_in(attribute, attribute_config["sanitized"]): - return attribute + return attribute def transform_attribute_to_k8s_class_instance( diff --git a/mlrun/runtimes/remotesparkjob.py b/mlrun/runtimes/remotesparkjob.py index c855126c27b..e441a764e94 100644 --- a/mlrun/runtimes/remotesparkjob.py +++ b/mlrun/runtimes/remotesparkjob.py @@ -54,6 +54,7 @@ def __init__( pythonpath=None, tolerations=None, preemption_mode=None, + security_context=None, ): super().__init__( command=command, @@ -81,6 +82,7 @@ def __init__( pythonpath=pythonpath, tolerations=tolerations, preemption_mode=preemption_mode, + security_context=security_context, ) self.provider = provider diff --git a/mlrun/runtimes/serving.py b/mlrun/runtimes/serving.py index fb7f7a92517..c64278b5869 100644 --- a/mlrun/runtimes/serving.py +++ b/mlrun/runtimes/serving.py @@ -138,6 +138,7 @@ def __init__( image_pull_secret=None, tolerations=None, preemption_mode=None, + security_context=None, ): super().__init__( @@ -174,6 +175,7 @@ def __init__( image_pull_secret=image_pull_secret, tolerations=tolerations, preemption_mode=preemption_mode, + security_context=security_context, ) self.models = models or {} diff --git a/mlrun/runtimes/sparkjob/abstract.py b/mlrun/runtimes/sparkjob/abstract.py index 2845012fa51..3f42d0f0cd5 100644 --- a/mlrun/runtimes/sparkjob/abstract.py +++ b/mlrun/runtimes/sparkjob/abstract.py @@ -141,6 +141,7 @@ def __init__( affinity=None, tolerations=None, preemption_mode=None, + security_context=None, ): super().__init__( @@ -169,6 +170,7 @@ def __init__( affinity=affinity, tolerations=tolerations, preemption_mode=preemption_mode, + security_context=security_context, ) self._driver_resources = self.enrich_resources_with_default_pod_resources( diff --git a/mlrun/runtimes/sparkjob/spark3job.py b/mlrun/runtimes/sparkjob/spark3job.py index 82a8c7dfc18..23d0b5eeebc 100644 --- a/mlrun/runtimes/sparkjob/spark3job.py +++ b/mlrun/runtimes/sparkjob/spark3job.py @@ -99,6 +99,7 @@ def __init__( executor_java_options=None, driver_cores=None, executor_cores=None, + security_context=None, ): super().__init__( @@ -127,6 +128,7 @@ def __init__( affinity=affinity, tolerations=tolerations, preemption_mode=preemption_mode, + security_context=security_context, ) self.driver_resources = driver_resources or {} @@ -160,26 +162,21 @@ def __init__( self.executor_cores = executor_cores def to_dict(self, fields=None, exclude=None): - struct = super().to_dict( - fields, - exclude=[ - "executor_affinity", - "executor_tolerations", - "driver_affinity", - "driver_tolerations", - ], - ) + exclude = exclude or [] + _exclude = [ + "affinity", + "tolerations", + "security_context", + "executor_affinity", + "executor_tolerations", + "driver_affinity", + "driver_tolerations", + ] + struct = super().to_dict(fields, exclude=list(set(exclude + _exclude))) api = kubernetes.client.ApiClient() - struct["executor_affinity"] = 
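# Illustrative sketch (not part of this patch): the underscore heuristic above
# rejects dicts produced by calling .to_dict() on a kubernetes object
# (snake_case keys), since at that point they can no longer be transformed
# back into the sanitized form.
import mlrun.errors
import mlrun.runtimes.pod

try:
    mlrun.runtimes.pod.transform_attribute_to_k8s_class_instance(
        "security_context", {"run_as_user": 1000}
    )
except mlrun.errors.MLRunInvalidArgumentTypeError as exc:
    print(exc)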
api.sanitize_for_serialization( - self.executor_affinity - ) - struct["driver_affinity"] = api.sanitize_for_serialization(self.driver_affinity) - struct["executor_tolerations"] = api.sanitize_for_serialization( - self.executor_tolerations - ) - struct["driver_tolerations"] = api.sanitize_for_serialization( - self.driver_tolerations - ) + for field in _exclude: + if field not in exclude: + struct[field] = api.sanitize_for_serialization(getattr(self, field)) return struct @property @@ -577,6 +574,18 @@ def with_executor_preemption_mode( preemption_mode = mlrun.api.schemas.function.PreemptionModes(mode) self.spec.executor_preemption_mode = preemption_mode.value + def with_security_context( + self, security_context: kubernetes.client.V1SecurityContext + ): + """ + With security context is not supported for spark runtime. + Driver / Executor processes run with uid / gid 1000 as long as security context is not defined. + If in the future we want to support setting security context it will work only from spark version 3.2 onwards. + """ + raise mlrun.errors.MLRunInvalidArgumentTypeError( + "with_security_context is not supported with spark operator" + ) + def with_driver_host_path_volume( self, host_path: str, diff --git a/mlrun/runtimes/utils.py b/mlrun/runtimes/utils.py index d13ae40af2b..2519508a003 100644 --- a/mlrun/runtimes/utils.py +++ b/mlrun/runtimes/utils.py @@ -604,6 +604,7 @@ def enrich_function_from_dict(function, function_dict): "credentials", "tolerations", "preemption_mode", + "security_context", ]: if attribute == "credentials": override_value = getattr(override_function.metadata, attribute, None) diff --git a/tests/api/runtimes/base.py b/tests/api/runtimes/base.py index 8822a090ce6..7fa06a4293f 100644 --- a/tests/api/runtimes/base.py +++ b/tests/api/runtimes/base.py @@ -280,6 +280,16 @@ def _generate_affinity(self) -> k8s_client.V1Affinity: ), ) + def _generate_security_context( + self, + run_as_user: typing.Optional[int] = None, + run_as_group: typing.Optional[int] = None, + ) -> k8s_client.V1SecurityContext: + return k8s_client.V1SecurityContext( + run_as_user=run_as_user, + run_as_group=run_as_group, + ) + def _mock_create_namespaced_pod(self): def _generate_pod(namespace, pod): terminated_container_state = client.V1ContainerStateTerminated( @@ -772,6 +782,12 @@ def assert_node_selection( ): pass + def assert_security_context( + self, + security_context=None, + ): + pass + def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations( self, ): diff --git a/tests/api/runtimes/test_dask.py b/tests/api/runtimes/test_dask.py index a8a8ff7c3d8..336a7b2e44d 100644 --- a/tests/api/runtimes/test_dask.py +++ b/tests/api/runtimes/test_dask.py @@ -111,6 +111,13 @@ def _assert_pods_resources( expected_scheduler_requests, ) + def assert_security_context( + self, + security_context=None, + ): + pod = self._get_pod_creation_args() + assert pod.spec.security_context == (security_context or {}) + def test_dask_runtime(self, db: Session, client: TestClient): runtime: mlrun.runtimes.DaskCluster = self._generate_runtime() @@ -313,3 +320,40 @@ def test_dask_with_default_node_selector(self, db: Session, client: TestClient): assert_namespace_env_variable=False, expected_node_selector=node_selector, ) + + def test_dask_with_default_security_context(self, db: Session, client: TestClient): + runtime = self._generate_runtime() + + _ = runtime.client + self.kube_cluster_mock.assert_called_once() + self.assert_security_context() + + default_security_context_dict = { + "runAsUser": 1000, + 
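# Illustrative sketch (not part of this patch): the spark operator runtime
# overrides with_security_context to fail fast - driver and executor pods run
# with uid/gid 1000 regardless, and per-pod security context would only be
# viable from spark 3.2 onwards. Assumes Spark3Runtime can be constructed
# standalone here; in a project it would normally come from
# mlrun.new_function(..., kind="spark").
from kubernetes import client as k8s_client

import mlrun.errors
from mlrun.runtimes.sparkjob.spark3job import Spark3Runtime

spark_function = Spark3Runtime()
try:
    spark_function.with_security_context(
        k8s_client.V1SecurityContext(run_as_user=1000)
    )
except mlrun.errors.MLRunInvalidArgumentTypeError:
    print("with_security_context is not supported with the spark operator")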
"runAsGroup": 3000, + } + default_security_context = self._generate_security_context( + default_security_context_dict["runAsUser"], + default_security_context_dict["runAsGroup"], + ) + + mlrun.mlconf.function.spec.security_context.default = base64.b64encode( + json.dumps(default_security_context_dict).encode("utf-8") + ) + runtime = self._generate_runtime() + + _ = runtime.client + assert self.kube_cluster_mock.call_count == 2 + self.assert_security_context(default_security_context) + + def test_dask_with_security_context(self, db: Session, client: TestClient): + runtime = self._generate_runtime() + other_security_context = self._generate_security_context( + 2000, + 2000, + ) + + # override security context + runtime.with_security_context(other_security_context) + _ = runtime.client + self.assert_security_context(other_security_context) diff --git a/tests/api/runtimes/test_kubejob.py b/tests/api/runtimes/test_kubejob.py index a3d37e799a3..e5ece629344 100644 --- a/tests/api/runtimes/test_kubejob.py +++ b/tests/api/runtimes/test_kubejob.py @@ -194,6 +194,13 @@ def assert_node_selection( else: assert pod.spec.tolerations is None + def assert_security_context( + self, + security_context=None, + ): + pod = self._get_pod_creation_args() + assert pod.spec.security_context == (security_context or {}) + def test_run_with_priority_class_name(self, db: Session, client: TestClient): runtime = self._generate_runtime() @@ -223,6 +230,39 @@ def test_run_with_priority_class_name(self, db: Session, client: TestClient): with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): runtime.with_priority_class(medium_priority_class_name) + def test_run_with_security_context(self, db: Session, client: TestClient): + runtime = self._generate_runtime() + + self.execute_function(runtime) + self.assert_security_context() + + default_security_context_dict = { + "runAsUser": 1000, + "runAsGroup": 3000, + } + default_security_context = self._generate_security_context( + default_security_context_dict["runAsUser"], + default_security_context_dict["runAsGroup"], + ) + + mlrun.mlconf.function.spec.security_context.default = base64.b64encode( + json.dumps(default_security_context_dict).encode("utf-8") + ) + runtime = self._generate_runtime() + + self.execute_function(runtime) + self.assert_security_context(default_security_context) + + # override default + other_security_context = self._generate_security_context( + run_as_group=2000, + ) + runtime = self._generate_runtime() + + runtime.with_security_context(other_security_context) + self.execute_function(runtime) + self.assert_security_context(other_security_context) + def test_run_with_mounts(self, db: Session, client: TestClient): runtime = self._generate_runtime() diff --git a/tests/api/runtimes/test_nuclio.py b/tests/api/runtimes/test_nuclio.py index f255e391a18..ff70ddbbbd4 100644 --- a/tests/api/runtimes/test_nuclio.py +++ b/tests/api/runtimes/test_nuclio.py @@ -321,6 +321,23 @@ def assert_node_selection( else: assert deploy_spec.get("tolerations") is None + def assert_security_context( + self, + security_context=None, + ): + args, _ = nuclio.deploy.deploy_config.call_args + deploy_spec = args[0]["spec"] + + if security_context: + assert ( + mlrun.runtimes.pod.transform_attribute_to_k8s_class_instance( + "security_context", deploy_spec.get("securityContext") + ) + == security_context + ) + else: + assert deploy_spec.get("securityContext") is None + def test_enrich_with_ingress_no_overriding(self, db: Session, client: TestClient): """ Expect no ingress template to be 
created, thought its mode is "always", @@ -1014,6 +1031,46 @@ def test_preemption_mode_with_preemptible_node_selector_without_preemptible_tole ): self.assert_run_preemption_mode_with_preemptible_node_selector_without_preemptible_tolerations_with_extra_settings() # noqa: E501 + def test_deploy_with_security_context(self, db: Session, client: TestClient): + function = self._generate_runtime(self.runtime_kind) + + self.execute_function(function) + self._assert_deploy_called_basic_config(expected_class=self.class_name) + self.assert_security_context() + + default_security_context_dict = { + "runAsUser": 1000, + "runAsGroup": 3000, + } + mlrun.mlconf.function.spec.security_context.default = base64.b64encode( + json.dumps(default_security_context_dict).encode("utf-8") + ) + default_security_context = self._generate_security_context( + default_security_context_dict["runAsUser"], + default_security_context_dict["runAsGroup"], + ) + function = self._generate_runtime(self.runtime_kind) + self.execute_function(function) + + self._assert_deploy_called_basic_config( + call_count=2, expected_class=self.class_name + ) + self.assert_security_context(default_security_context) + + function = self._generate_runtime(self.runtime_kind) + other_security_context = self._generate_security_context( + 2000, + 2000, + ) + + function.with_security_context(other_security_context) + self.execute_function(function) + + self._assert_deploy_called_basic_config( + call_count=3, expected_class=self.class_name + ) + self.assert_security_context(other_security_context) + # Kind of "nuclio:mlrun" is a special case of nuclio functions. Run the same suite of tests here as well class TestNuclioMLRunRuntime(TestNuclioRuntime): diff --git a/tests/api/runtimes/test_spark.py b/tests/api/runtimes/test_spark.py index 718b12d7ee0..c4c585358de 100644 --- a/tests/api/runtimes/test_spark.py +++ b/tests/api/runtimes/test_spark.py @@ -156,6 +156,30 @@ def _assert_limits(actual, expected): assert actual["gpu"]["name"] == expected["gpu_type"] assert actual["gpu"]["quantity"] == expected["gpus"] + def _assert_security_context( + self, + expected_driver_security_context=None, + expected_executor_security_context=None, + ): + + body = self._get_custom_object_creation_body() + + if expected_driver_security_context: + assert ( + body["spec"]["driver"].get("securityContext") + == expected_driver_security_context + ) + else: + assert body["spec"]["driver"].get("securityContext") is None + + if expected_executor_security_context: + assert ( + body["spec"]["executor"].get("securityContext") + == expected_executor_security_context + ) + else: + assert body["spec"]["executor"].get("securityContext") is None + def _sanitize_list_for_serialization(self, list_: list): kubernetes_api_client = kubernetes.client.ApiClient() return list(map(kubernetes_api_client.sanitize_for_serialization, list_)) diff --git a/tests/runtimes/test_pod.py b/tests/runtimes/test_pod.py index 643221d050a..83f8afa8445 100644 --- a/tests/runtimes/test_pod.py +++ b/tests/runtimes/test_pod.py @@ -33,6 +33,9 @@ def test_runtimes_inheritance(): mlrun.runtimes.sparkjob.spark2job.Spark2JobSpec, mlrun.runtimes.sparkjob.spark3job.Spark3JobSpec, ], + mlrun.runtimes.function.NuclioSpec: [ + mlrun.runtimes.serving.ServingSpec, + ], mlrun.runtimes.base.FunctionStatus: [ mlrun.runtimes.daskjob.DaskStatus, mlrun.runtimes.function.NuclioStatus, @@ -51,9 +54,9 @@ def test_runtimes_inheritance(): mlrun.runtimes.sparkjob.spark3job.Spark3Runtime, ], } - checked_classes = set() invalid_classes = {} for 
base_class, inheriting_classes in classes_map.items(): + checked_classes = set() for inheriting_class in inheriting_classes: for class_ in inspect.getmro(inheriting_class): if base_class == class_: diff --git a/tests/runtimes/test_run.py b/tests/runtimes/test_run.py index ea47c979287..945b6c8ce6f 100644 --- a/tests/runtimes/test_run.py +++ b/tests/runtimes/test_run.py @@ -35,6 +35,7 @@ def _get_runtime(): "disable_auto_mount": False, "priority_class_name": "", "tolerations": None, + "security_context": None, }, "verbose": False, }