Skip to content

Commit

Permalink
[Runtimes] - Security context support (mlrun#2124)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonmr authored Jul 19, 2022
1 parent 640f925 commit a86996b
Show file tree
Hide file tree
Showing 20 changed files with 326 additions and 116 deletions.
5 changes: 5 additions & 0 deletions mlrun/api/schemas/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@ class ImagePullSecret(pydantic.BaseModel):
default: typing.Optional[str]


class SecurityContext(pydantic.BaseModel):
    """Security-context settings of a function spec, as exposed by the API schema."""

    # optional string value
    # NOTE(review): presumably the base64-encoded JSON default from the config - confirm
    default: typing.Optional[str]


class FunctionSpec(pydantic.BaseModel):
    """Function spec schema grouping the optional per-attribute settings models."""

    # optional image pull secret settings
    image_pull_secret: typing.Optional[ImagePullSecret]
    # optional security context settings
    security_context: typing.Optional[SecurityContext]


class Function(pydantic.BaseModel):
Expand Down
10 changes: 10 additions & 0 deletions mlrun/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@
"function": {
"spec": {
"image_pull_secret": {"default": None},
"security_context": {
# default security context applied to all functions: a base64-encoded JSON string
# whose keys are in camelCase, e.g. {"runAsUser": 1000, "runAsGroup": 3000}
"default": "e30=", # encoded empty dict
},
},
},
"function_defaults": {
Expand Down Expand Up @@ -512,6 +517,11 @@ def get_preemptible_tolerations(self) -> list:
"preemptible_nodes.tolerations", list
)

def get_default_function_security_context(self) -> dict:
    """Return the default function security context from the configuration.

    Delegates to ``decode_base64_config_and_load_to_object`` using the
    ``function.spec.security_context.default`` key and ``dict`` as the expected type.
    """
    config_key = "function.spec.security_context.default"
    default_security_context = self.decode_base64_config_and_load_to_object(
        config_key, dict
    )
    return default_security_context

def is_preemption_nodes_configured(self):
if (
not self.get_preemptible_tolerations()
Expand Down
2 changes: 2 additions & 0 deletions mlrun/runtimes/daskjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
workdir=None,
tolerations=None,
preemption_mode=None,
security_context=None,
):

super().__init__(
Expand Down Expand Up @@ -131,6 +132,7 @@ def __init__(
workdir=workdir,
tolerations=tolerations,
preemption_mode=preemption_mode,
security_context=security_context,
)
self.args = args

Expand Down
10 changes: 10 additions & 0 deletions mlrun/runtimes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __init__(
image_pull_secret=None,
tolerations=None,
preemption_mode=None,
security_context=None,
):

super().__init__(
Expand Down Expand Up @@ -193,6 +194,7 @@ def __init__(
image_pull_secret=image_pull_secret,
tolerations=tolerations,
preemption_mode=preemption_mode,
security_context=security_context,
)

self.base_spec = base_spec or {}
Expand Down Expand Up @@ -1217,6 +1219,14 @@ def compile_function_config(
if function.spec.service_account:
nuclio_spec.set_config("spec.serviceAccount", function.spec.service_account)

if function.spec.security_context:
nuclio_spec.set_config(
"spec.securityContext",
mlrun.runtimes.pod.get_sanitized_attribute(
function.spec, "security_context"
),
)

if (
function.spec.base_spec
or function.spec.build.functionSourceCode
Expand Down
2 changes: 2 additions & 0 deletions mlrun/runtimes/mpijob/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
pythonpath=None,
tolerations=None,
preemption_mode=None,
security_context=None,
):
super().__init__(
command=command,
Expand Down Expand Up @@ -85,6 +86,7 @@ def __init__(
pythonpath=pythonpath,
tolerations=tolerations,
preemption_mode=preemption_mode,
security_context=security_context,
)
self.mpi_args = mpi_args or [
"-x",
Expand Down
10 changes: 10 additions & 0 deletions mlrun/runtimes/mpijob/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
pythonpath=None,
tolerations=None,
preemption_mode=None,
security_context=None,
):
super().__init__(
command=command,
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
pythonpath=pythonpath,
tolerations=tolerations,
preemption_mode=preemption_mode,
security_context=security_context,
)
self.clean_pod_policy = clean_pod_policy or MPIJobV1CleanPodPolicies.default()

Expand Down Expand Up @@ -213,6 +215,14 @@ def _generate_mpi_job(
"spec.imagePullSecrets",
[{"name": self.spec.image_pull_secret}],
)
if self.spec.security_context:
update_in(
pod_template,
"spec.securityContext",
mlrun.runtimes.pod.get_sanitized_attribute(
self.spec, "security_context"
),
)
update_in(pod_template, "metadata.labels", pod_labels)
update_in(pod_template, "spec.volumes", self.spec.volumes)
update_in(pod_template, "spec.nodeName", self.spec.node_name)
Expand Down
5 changes: 5 additions & 0 deletions mlrun/runtimes/mpijob/v1alpha1.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ def _generate_mpi_job(
"spec.template.spec.tolerations",
mlrun.runtimes.pod.get_sanitized_attribute(self.spec, "tolerations"),
)
update_in(
job,
"spec.template.spec.securityContext",
mlrun.runtimes.pod.get_sanitized_attribute(self.spec, "security_context"),
)
if self.spec.priority_class_name and len(
mlconf.get_valid_function_priority_class_names()
):
Expand Down
157 changes: 61 additions & 96 deletions mlrun/runtimes/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
generate_preemptible_tolerations,
)
from ..secrets import SecretsStore
from ..utils import get_in, logger, normalize_name, update_in
from ..utils import logger, normalize_name, update_in
from .base import BaseRuntime, FunctionSpec, spec_fields
from .utils import (
apply_kfp,
Expand All @@ -51,28 +51,34 @@
"attribute_type": k8s_client.V1Affinity,
"sub_attribute_type": None,
"contains_many": False,
"not_sanitized": "node_affinity",
"not_sanitized_class": dict,
"sanitized": "nodeAffinity",
},
"tolerations": {
"attribute_type_name": "List[V1.Toleration]",
"attribute_type": list,
"contains_many": True,
"sub_attribute_type": k8s_client.V1Toleration,
"not_sanitized": "toleration_seconds",
"not_sanitized_class": list,
"sanitized": "tolerationSeconds",
},
"security_context": {
"attribute_type_name": "V1SecurityContext",
"attribute_type": k8s_client.V1SecurityContext,
"sub_attribute_type": None,
"contains_many": False,
"not_sanitized_class": dict,
},
}

# Maps each spec attribute name that supports sanitization to its conversion config
# from `sanitized_types`. The executor_/driver_ prefixed names reuse the very same
# config objects as their un-prefixed counterparts.
sanitized_attributes = {
"affinity": sanitized_types["affinity"],
"tolerations": sanitized_types["tolerations"],
"security_context": sanitized_types["security_context"],
"executor_tolerations": sanitized_types["tolerations"],
"driver_tolerations": sanitized_types["tolerations"],
"executor_affinity": sanitized_types["affinity"],
"driver_affinity": sanitized_types["affinity"],
"executor_security_context": sanitized_types["security_context"],
"driver_security_context": sanitized_types["security_context"],
}


Expand All @@ -92,6 +98,7 @@ class KubeResourceSpec(FunctionSpec):
"priority_class_name",
"tolerations",
"preemption_mode",
"security_context",
]

def __init__(
Expand Down Expand Up @@ -121,6 +128,7 @@ def __init__(
priority_class_name=None,
tolerations=None,
preemption_mode=None,
security_context=None,
):
super().__init__(
command=command,
Expand Down Expand Up @@ -160,6 +168,9 @@ def __init__(
)
self._tolerations = tolerations
self.preemption_mode = preemption_mode
self.security_context = (
security_context or mlrun.mlconf.get_default_function_security_context()
)

@property
def volumes(self) -> list:
Expand Down Expand Up @@ -220,11 +231,24 @@ def preemption_mode(self, mode):
self._preemption_mode = mode or mlconf.function_defaults.preemption_mode
self.enrich_function_preemption_spec()

@property
def security_context(self) -> k8s_client.V1SecurityContext:
    """Security context held by this spec (stored in ``_security_context``)."""
    return self._security_context

@security_context.setter
def security_context(self, security_context):
    """Store the given value after passing it through
    ``transform_attribute_to_k8s_class_instance`` for the ``security_context`` attribute."""
    k8s_security_context = transform_attribute_to_k8s_class_instance(
        "security_context", security_context
    )
    self._security_context = k8s_security_context

def to_dict(self, fields=None, exclude=None):
    """Serialize the spec to a dict, handling the kubernetes-typed attributes separately.

    ``affinity``, ``tolerations`` and ``security_context`` are always left out of the
    parent serialization; each is then added through the kubernetes client's
    ``sanitize_for_serialization`` unless the caller listed it in ``exclude``.

    :param fields:  passed through to the parent ``to_dict``
    :param exclude: optional iterable of attribute names to leave out of the result
    :return: the spec as a dict
    """
    exclude = list(exclude or [])
    k8s_fields = ["affinity", "tolerations", "security_context"]
    # serialize once, skipping both the caller's exclusions and the fields handled below
    struct = super().to_dict(fields, exclude=list(set(exclude + k8s_fields)))
    api = k8s_client.ApiClient()
    for field in k8s_fields:
        # honor the caller's exclusions for the kubernetes-typed fields as well
        if field not in exclude:
            struct[field] = api.sanitize_for_serialization(getattr(self, field))
    return struct

def update_vols_and_mounts(
Expand All @@ -241,88 +265,6 @@ def update_vols_and_mounts(
def _get_affinity_as_k8s_class_instance(self):
pass

def _transform_attribute_to_k8s_class_instance(
    self, attribute_name, attribute, is_sub_attr: bool = False
):
    """Convert a sanitized dict (or a list of such dicts) to kubernetes class instance(s).

    :param attribute_name: key in ``sanitized_attributes`` selecting the conversion config
    :param attribute:      ``None``, a dict, a list, or any other value (returned as is, apart
                           from the list wrapping described below)
    :param is_sub_attr:    ``True`` when called recursively for an item of a list, so that a
                           single item is not wrapped in a list again
    :return: ``None`` for ``None`` input, otherwise the (possibly converted) attribute
    :raises mlrun.errors.MLRunInvalidArgumentError: if ``attribute_name`` is not a key of
        ``sanitized_attributes``
    """
    if attribute_name not in sanitized_attributes:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"{attribute_name} isn't in the available sanitized attributes"
        )
    attribute_config = sanitized_attributes[attribute_name]
    # nothing to transform
    if attribute is None:
        return None
    if isinstance(attribute, dict):
        # _resolve_if_type_sanitized is a module-level function, not a method of this class
        if _resolve_if_type_sanitized(attribute_name, attribute):
            api = k8s_client.ApiClient()
            # not ideal to use their private method, but looks like that's the only option
            # Taken from https://github.com/kubernetes-client/python/issues/977
            attribute_type = attribute_config["attribute_type"]
            if attribute_config["contains_many"]:
                attribute_type = attribute_config["sub_attribute_type"]
            attribute = api._ApiClient__deserialize(attribute, attribute_type)

    elif isinstance(attribute, list):
        attribute_instance = []
        for sub_attr in attribute:
            # a non-dict item means the list is not made of sanitized dicts - return it untouched
            if not isinstance(sub_attr, dict):
                return attribute
            attribute_instance.append(
                self._transform_attribute_to_k8s_class_instance(
                    attribute_name, sub_attr, is_sub_attr=True
                )
            )
        attribute = attribute_instance
    # if the user set a single item of an attribute that contains many, return it inside a list
    if (
        not is_sub_attr
        and attribute_config["contains_many"]
        and isinstance(attribute, attribute_config["sub_attribute_type"])
    ):
        # initialize the container type and add the single item to it
        attribute_instance = attribute_config["attribute_type"]()
        attribute_instance.append(attribute)
        return attribute_instance
    return attribute

def _get_sanitized_attribute(self, attribute_name: str):
    """
    When using methods like to_dict() on kubernetes class instances we're getting the attributes in snake_case
    Which is ok if we're using the kubernetes python package but not if for example we're creating CRDs that we
    apply directly. For that we need the sanitized (CamelCase) version.

    :param attribute_name: name of the attribute on ``self``; must be a key of ``sanitized_attributes``
    :return: an empty instance of the configured ``not_sanitized_class`` when the attribute is
        falsy, the attribute itself when it is already a sanitized dict / list of dicts,
        otherwise the result of the kubernetes client's ``sanitize_for_serialization``
    :raises mlrun.errors.MLRunInvalidArgumentError: if ``attribute_name`` is not a key of
        ``sanitized_attributes``
    :raises mlrun.errors.MLRunInvalidArgumentTypeError: if a dict or list is given for an
        attribute whose configured ``not_sanitized_class`` is a different type
    """
    attribute = getattr(self, attribute_name)
    if attribute_name not in sanitized_attributes:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"{attribute_name} isn't in the available sanitized attributes"
        )
    attribute_config = sanitized_attributes[attribute_name]
    expected_class = attribute_config["not_sanitized_class"]
    if not attribute:
        return expected_class()

    # check if attribute of type dict, and then check if type is sanitized
    if isinstance(attribute, dict):
        if expected_class != dict:
            raise mlrun.errors.MLRunInvalidArgumentTypeError(
                f"expected to to be of type {attribute_config.get('not_sanitized_class')} but got dict"
            )
        if _resolve_if_type_sanitized(attribute_name, attribute):
            return attribute

    elif isinstance(attribute, list):
        # validate the container type before looking at the items: attributes that are not
        # lists have no sub attribute type (None), which isinstance() can't accept
        if expected_class != list:
            raise mlrun.errors.MLRunInvalidArgumentTypeError(
                f"expected to to be of type {attribute_config.get('not_sanitized_class')} but got list"
            )
        # the list is non-empty here, empty values were handled above
        if not isinstance(
            attribute[0], attribute_config["sub_attribute_type"]
        ) and _resolve_if_type_sanitized(attribute_name, attribute[0]):
            return attribute

    api = k8s_client.ApiClient()
    return api.sanitize_for_serialization(attribute)

def _set_volume_mount(
self, volume_mount, volume_mounts_field_name="_volume_mounts"
):
Expand Down Expand Up @@ -1062,6 +1004,26 @@ def with_preemption_mode(self, mode: typing.Union[PreemptionModes, str]):
preemptible_mode = PreemptionModes(mode)
self.spec.preemption_mode = preemptible_mode.value

def with_security_context(self, security_context: k8s_client.V1SecurityContext):
    """
    Assign the security context of the function's pod on the function spec

    Usage::

        from kubernetes import client as k8s_client

        security_context = k8s_client.V1SecurityContext(
            run_as_user=1000,
            run_as_group=3000,
        )
        function.with_security_context(security_context)

    Further reading:
    https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#set-the-security-context-for-a-pod

    :param security_context: the security context to store in ``self.spec.security_context``
    """
    self.spec.security_context = security_context

def list_valid_priority_class_names(self):
    """Return the valid function priority class names as reported by ``mlconf``."""
    valid_names = mlconf.get_valid_function_priority_class_names()
    return valid_names

Expand Down Expand Up @@ -1268,20 +1230,23 @@ def kube_resource_spec_to_pod_spec(
if len(mlconf.get_valid_function_priority_class_names())
else None,
tolerations=kube_resource_spec.tolerations,
security_context=kube_resource_spec.security_context,
)


def _resolve_if_type_sanitized(attribute_name, attribute):
    """Verify that a dict attribute is in its sanitized form and return it.

    :param attribute_name: key in ``sanitized_attributes``, used for the error message
    :param attribute:      dict whose top-level keys are checked
    :return: the given ``attribute``, unchanged
    :raises mlrun.errors.MLRunInvalidArgumentTypeError: if any top-level key contains ``_``
    """
    attribute_config = sanitized_attributes[attribute_name]
    # heuristic - if one of the keys contains _ as part of the dict it means to_dict on the kubernetes
    # object performed, there's nothing we can do at that point to transform it to the sanitized version
    for key in attribute.keys():
        if "_" in key:
            raise mlrun.errors.MLRunInvalidArgumentTypeError(
                f"{attribute_name} must be instance of kubernetes {attribute_config.get('attribute_type_name')} class "
                f"but contains not sanitized key: {key}"
            )

    # no snake_case key was found, so it's already the sanitized version
    return attribute


def transform_attribute_to_k8s_class_instance(
Expand Down
Loading

0 comments on commit a86996b

Please sign in to comment.