From cc1595317d0bdca10606c7f6adad1b92b389a3f9 Mon Sep 17 00:00:00 2001 From: MASisserson <66091016+MASisserson@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:20:36 -0400 Subject: [PATCH] Enhance stack validation (#148) * Added spec validations in Stack and Component pydantic models. Added check for mismatch between stack and component provider to yaml_utils.py * Working out testing changes * Changed tests and added test_utils * Made changes according to formatter and linter. * Made final changes for pull request. Got rid of comments and print calls. * Removed a comment in yaml_utils.py --------- Co-authored-by: Alex Strick van Linschoten --- src/mlstacks/constants.py | 62 +++++++++++++++++++ src/mlstacks/enums.py | 20 ++++++ src/mlstacks/models/component.py | 77 +++++++++++++++++++++--- src/mlstacks/models/stack.py | 8 ++- src/mlstacks/utils/model_utils.py | 46 +++++++++++++- src/mlstacks/utils/yaml_utils.py | 25 +++++++- tests/test_utils.py | 34 +++++++++++ tests/unit/models/test_component.py | 45 ++++++++++++-- tests/unit/utils/test_terraform_utils.py | 21 +++++-- tests/unit/utils/test_zenml_utils.py | 5 +- 10 files changed, 319 insertions(+), 24 deletions(-) create mode 100644 tests/test_utils.py diff --git a/src/mlstacks/constants.py b/src/mlstacks/constants.py index 1216c784..3b88745b 100644 --- a/src/mlstacks/constants.py +++ b/src/mlstacks/constants.py @@ -12,6 +12,8 @@ # permissions and limitations under the License. """MLStacks constants.""" +from typing import Dict, List + MLSTACKS_PACKAGE_NAME = "mlstacks" MLSTACKS_INITIALIZATION_FILE_FLAG = "IGNORE_ME" MLSTACKS_STACK_COMPONENT_FLAGS = [ @@ -39,6 +41,52 @@ "model_deployer": ["seldon"], "step_operator": ["sagemaker", "vertex"], } +ALLOWED_COMPONENT_TYPES: Dict[str, Dict[str, List[str]]] = { + "aws": { + "artifact_store": ["s3"], + "container_registry": ["aws"], + "experiment_tracker": ["mlflow"], + "orchestrator": [ + "kubeflow", + "kubernetes", + "sagemaker", + "skypilot", + "tekton", + ], + "mlops_platform": ["zenml"], + "model_deployer": ["seldon"], + "step_operator": ["sagemaker"], + }, + "azure": {}, + "gcp": { + "artifact_store": ["gcp"], + "container_registry": ["gcp"], + "experiment_tracker": ["mlflow"], + "orchestrator": [ + "kubeflow", + "kubernetes", + "skypilot", + "tekton", + "vertex", + ], + "mlops_platform": ["zenml"], + "model_deployer": ["seldon"], + "step_operator": ["vertex"], + }, + "k3d": { + "artifact_store": ["minio"], + "container_registry": ["default"], + "experiment_tracker": ["mlflow"], + "orchestrator": [ + "kubeflow", + "kubernetes", + "sagemaker", + "tekton", + ], + "mlops_platform": ["zenml"], + "model_deployer": ["seldon"], + }, +} PERMITTED_NAME_REGEX = r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$" ANALYTICS_OPT_IN_ENV_VARIABLE = "MLSTACKS_ANALYTICS_OPT_IN" @@ -49,5 +97,19 @@ "contain alphanumeric characters, underscores, and hyphens " "thereafter." ) +INVALID_COMPONENT_TYPE_ERROR_MESSAGE = ( + "Artifact Store, Container Registry, Experiment Tracker, Orchestrator, " + "MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d " + "providers. Step Operator may only be used with aws and gcp." +) +INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE = ( + "Only certain flavors are allowed for a given provider-component type " + "combination. For more information, consult the tables for your specified " + "provider at the MLStacks documentation: " + "https://mlstacks.zenml.io/stacks/stack-specification." +) +STACK_COMPONENT_PROVIDER_MISMATCH_ERROR_MESSAGE = ( + "Stack provider and component provider mismatch." +) DEFAULT_REMOTE_STATE_BUCKET_NAME = "zenml-mlstacks-remote-state" TERRAFORM_CONFIG_BUCKET_REPLACEMENT_STRING = "BUCKETNAMEREPLACEME" diff --git a/src/mlstacks/enums.py b/src/mlstacks/enums.py index 45bbada8..122e2806 100644 --- a/src/mlstacks/enums.py +++ b/src/mlstacks/enums.py @@ -49,6 +49,7 @@ class ComponentFlavorEnum(str, Enum): TEKTON = "tekton" VERTEX = "vertex" ZENML = "zenml" + DEFAULT = "default" class DeploymentMethodEnum(str, Enum): @@ -77,3 +78,22 @@ class AnalyticsEventsEnum(str, Enum): MLSTACKS_SOURCE = "MLStacks Source" MLSTACKS_EXCEPTION = "MLStacks Exception" MLSTACKS_VERSION = "MLStacks Version" + + +class SpecTypeEnum(str, Enum): + """Spec type enum.""" + + STACK = "stack" + COMPONENT = "component" + + +class StackSpecVersionEnum(int, Enum): + """Spec version enum.""" + + ONE = 1 + + +class ComponentSpecVersionEnum(int, Enum): + """Spec version enum.""" + + ONE = 1 diff --git a/src/mlstacks/models/component.py b/src/mlstacks/models/component.py index 31a1db4a..4eed4395 100644 --- a/src/mlstacks/models/component.py +++ b/src/mlstacks/models/component.py @@ -12,17 +12,27 @@ # permissions and limitations under the License. """Component model.""" -from typing import Dict, Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, validator -from mlstacks.constants import INVALID_NAME_ERROR_MESSAGE +from mlstacks.constants import ( + INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE, + INVALID_COMPONENT_TYPE_ERROR_MESSAGE, + INVALID_NAME_ERROR_MESSAGE, +) from mlstacks.enums import ( ComponentFlavorEnum, + ComponentSpecVersionEnum, ComponentTypeEnum, ProviderEnum, + SpecTypeEnum, +) +from mlstacks.utils.model_utils import ( + is_valid_component_flavor, + is_valid_component_type, + is_valid_name, ) -from mlstacks.utils.model_utils import is_valid_name class ComponentMetadata(BaseModel): @@ -49,16 +59,16 @@ class Component(BaseModel): metadata: The metadata of the component. """ - spec_version: int = 1 - spec_type: str = "component" + spec_version: ComponentSpecVersionEnum = ComponentSpecVersionEnum.ONE + spec_type: SpecTypeEnum = SpecTypeEnum.COMPONENT name: str + provider: ProviderEnum component_type: ComponentTypeEnum component_flavor: ComponentFlavorEnum - provider: ProviderEnum metadata: Optional[ComponentMetadata] = None @validator("name") - def validate_name(cls, name: str) -> str: # noqa: N805 + def validate_name(cls, name: str) -> str: # noqa """Validate the name. Name must start with an alphanumeric character and can only contain @@ -78,3 +88,56 @@ def validate_name(cls, name: str) -> str: # noqa: N805 if not is_valid_name(name): raise ValueError(INVALID_NAME_ERROR_MESSAGE) return name + + @validator("component_type") + def validate_component_type( + cls, # noqa + component_type: str, + values: Dict[str, Any], + ) -> str: + """Validate the component type. + + Artifact Store, Container Registry, Experiment Tracker, Orchestrator, + MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d + providers. Step Operator may only be used with aws and gcp. + + Args: + component_type: The component type. + values: The previously validated component specs. + + Returns: + The validated component type. + + Raises: + ValueError: If the component type is invalid. + """ + if not is_valid_component_type(component_type, values["provider"]): + raise ValueError(INVALID_COMPONENT_TYPE_ERROR_MESSAGE) + return component_type + + @validator("component_flavor") + def validate_component_flavor( + cls, # noqa + component_flavor: str, + values: Dict[str, Any], + ) -> str: + """Validate the component flavor. + + Only certain flavors are allowed for a given provider-component + type combination. For more information, consult the tables for + your specified provider at the MLStacks documentation: + https://mlstacks.zenml.io/stacks/stack-specification. + + Args: + component_flavor: The component flavor. + values: The previously validated component specs. + + Returns: + The validated component flavor. + + Raises: + ValueError: If the component flavor is invalid. + """ + if not is_valid_component_flavor(component_flavor, values): + raise ValueError(INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE) + return component_flavor diff --git a/src/mlstacks/models/stack.py b/src/mlstacks/models/stack.py index 1afebce5..b32bdba4 100644 --- a/src/mlstacks/models/stack.py +++ b/src/mlstacks/models/stack.py @@ -19,6 +19,8 @@ from mlstacks.enums import ( DeploymentMethodEnum, ProviderEnum, + SpecTypeEnum, + StackSpecVersionEnum, ) from mlstacks.models.component import Component from mlstacks.utils.model_utils import is_valid_name @@ -38,8 +40,8 @@ class Stack(BaseModel): components: The components of the stack. """ - spec_version: int = 1 - spec_type: str = "stack" + spec_version: StackSpecVersionEnum = StackSpecVersionEnum.ONE + spec_type: SpecTypeEnum = SpecTypeEnum.STACK name: str provider: ProviderEnum default_region: Optional[str] @@ -50,7 +52,7 @@ class Stack(BaseModel): components: List[Component] = [] @validator("name") - def validate_name(cls, name: str) -> str: # noqa: N805 + def validate_name(cls, name: str) -> str: # noqa """Validate the name. Name must start with an alphanumeric character and can only contain diff --git a/src/mlstacks/utils/model_utils.py b/src/mlstacks/utils/model_utils.py index 286382e3..e42c23d5 100644 --- a/src/mlstacks/utils/model_utils.py +++ b/src/mlstacks/utils/model_utils.py @@ -13,8 +13,9 @@ """Util functions for Pydantic models and validation.""" import re +from typing import Any, Dict -from mlstacks.constants import PERMITTED_NAME_REGEX +from mlstacks.constants import ALLOWED_COMPONENT_TYPES, PERMITTED_NAME_REGEX def is_valid_name(name: str) -> bool: @@ -29,3 +30,46 @@ def is_valid_name(name: str) -> bool: True if the name is valid, False otherwise. """ return re.match(PERMITTED_NAME_REGEX, name) is not None + + +def is_valid_component_type(component_type: str, provider: str) -> bool: + """Check if the component type is valid. + + Used for components. + + Args: + component_type: The component type. + provider: The provider. + + Returns: + True if the component type is valid, False otherwise. + """ + allowed_types = list(ALLOWED_COMPONENT_TYPES[provider].keys()) + return component_type in allowed_types + + +def is_valid_component_flavor( + component_flavor: str, specs: Dict[str, Any] +) -> bool: + """Check if the component flavor is valid. + + Used for components. + + Args: + component_flavor: The component flavor. + specs: The previously validated component specs. + + Returns: + True if the component flavor is valid, False otherwise. + """ + try: + is_valid = ( + component_flavor + in ALLOWED_COMPONENT_TYPES[specs["provider"]][ + specs["component_type"] + ] + ) + except KeyError: + return False + + return is_valid diff --git a/src/mlstacks/utils/yaml_utils.py b/src/mlstacks/utils/yaml_utils.py index 422e5cdf..0ef0734a 100644 --- a/src/mlstacks/utils/yaml_utils.py +++ b/src/mlstacks/utils/yaml_utils.py @@ -16,6 +16,7 @@ import yaml +from mlstacks.constants import STACK_COMPONENT_PROVIDER_MISMATCH_ERROR_MESSAGE from mlstacks.models.component import ( Component, ComponentMetadata, @@ -57,9 +58,17 @@ def load_component_yaml(path: str) -> Component: Returns: The component model. + + Raises: + FileNotFoundError: If the file is not found. """ - with open(path) as file: - component_data = yaml.safe_load(file) + try: + with open(path) as file: + component_data = yaml.safe_load(file) + except FileNotFoundError as exc: + error_message = f"""Component file at "{path}" specified in + the stack spec file could not be found.""" + raise FileNotFoundError(error_message) from exc if component_data.get("metadata") is None: component_data["metadata"] = {} @@ -88,6 +97,9 @@ def load_stack_yaml(path: str) -> Stack: Returns: The stack model. + + Raises: + ValueError: If the stack and component have different providers """ with open(path) as yaml_file: stack_data = yaml.safe_load(yaml_file) @@ -95,7 +107,8 @@ def load_stack_yaml(path: str) -> Stack: if component_data is None: component_data = [] - return Stack( + + stack = Stack( spec_version=stack_data.get("spec_version"), spec_type=stack_data.get("spec_type"), name=stack_data.get("name"), @@ -107,3 +120,9 @@ def load_stack_yaml(path: str) -> Stack: load_component_yaml(component) for component in component_data ], ) + + for component in stack.components: + if component.provider != stack.provider: + raise ValueError(STACK_COMPONENT_PROVIDER_MISMATCH_ERROR_MESSAGE) + + return stack diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..6d675268 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,34 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Util functions for tests.""" + +from typing import List + +from mlstacks.enums import ProviderEnum + + +def get_allowed_providers() -> List[str]: + """Filter out unimplemented providers. + + Used for component and stack testing. + + Returns: + A list of implemented providers + """ + # Filter out AZURE + excluded_providers = ["azure"] + return [ + provider.value + for provider in ProviderEnum + if provider.value not in excluded_providers + ] diff --git a/tests/unit/models/test_component.py b/tests/unit/models/test_component.py index 3dec23fd..ce8f0b1d 100644 --- a/tests/unit/models/test_component.py +++ b/tests/unit/models/test_component.py @@ -11,14 +11,50 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. -from hypothesis import given +from hypothesis import assume, given from hypothesis import strategies as st +from hypothesis.strategies import composite -from mlstacks.constants import PERMITTED_NAME_REGEX -from mlstacks.enums import ComponentFlavorEnum, ComponentTypeEnum +from mlstacks.constants import ALLOWED_COMPONENT_TYPES, PERMITTED_NAME_REGEX +from mlstacks.enums import ( + ComponentFlavorEnum, + ComponentTypeEnum, + ProviderEnum, +) from mlstacks.models.component import Component, ComponentMetadata +@composite +def valid_components(draw): + # Drawing a valid provider enum member directly + provider = draw(st.sampled_from([provider for provider in ProviderEnum])) + + # component_types and component_flavors are mappings to strings, + # and model or validation layer handles string to enum conversion: + component_types = list(ALLOWED_COMPONENT_TYPES[provider.value].keys()) + assume(component_types) + component_type = draw(st.sampled_from(component_types)) + + component_flavors = ALLOWED_COMPONENT_TYPES[provider.value][component_type] + assume(component_flavors) + + component_flavor_str = draw(st.sampled_from(component_flavors)) + component_flavor_enum = ComponentFlavorEnum( + component_flavor_str + ) # Convert string to enum + + # Constructing the Component instance with valid fields + return Component( + name=draw(st.from_regex(PERMITTED_NAME_REGEX)), + provider=provider.value, + component_type=component_type, + component_flavor=component_flavor_enum, + spec_version=1, + spec_type="component", + metadata=None, + ) + + @given(st.builds(ComponentMetadata)) def test_component_metadata(instance): assert instance.config is None or isinstance(instance.config, dict) @@ -27,8 +63,9 @@ def test_component_metadata(instance): ) -@given(st.builds(Component, name=st.from_regex(PERMITTED_NAME_REGEX))) +@given(valid_components()) def test_component(instance): + print(f"instance: {instance}") assert isinstance(instance.spec_version, int) assert isinstance(instance.spec_type, str) assert isinstance(instance.name, str) diff --git a/tests/unit/utils/test_terraform_utils.py b/tests/unit/utils/test_terraform_utils.py index 5448d2b9..69abd74d 100644 --- a/tests/unit/utils/test_terraform_utils.py +++ b/tests/unit/utils/test_terraform_utils.py @@ -36,6 +36,7 @@ remote_state_bucket_exists, tf_definitions_present, ) +from tests.test_utils import get_allowed_providers EXISTING_S3_BUCKET_URL = "s3://public-flavor-logos" EXISTING_S3_BUCKET_REGION = "eu-central-1" @@ -111,11 +112,12 @@ def test_enable_key_function_handles_components_without_flavors( """ comp_flavor = "s3" comp_type = "artifact_store" + comp_provider = "aws" c = Component( name=dummy_name, component_flavor=comp_flavor, component_type=comp_type, - provider=random.choice(list(ProviderEnum)).value, + provider=comp_provider, ) key = _compose_enable_key(c) assert key == "enable_artifact_store" @@ -125,12 +127,16 @@ def test_component_variable_parsing_works(): """Tests that the component variable parsing works.""" metadata = ComponentMetadata() component_flavor = "zenml" + + allowed_providers = get_allowed_providers() + random_test = random.choice(allowed_providers) + components = [ Component( name="test", component_flavor=component_flavor, component_type="mlops_platform", - provider=random.choice(list(ProviderEnum)).value, + provider=random_test, spec_type="component", spec_version=1, metadata=metadata, @@ -146,12 +152,17 @@ def test_component_var_parsing_works_for_env_vars(): """Tests that the component variable parsing works.""" env_vars = {"ARIA_KEY": "blupus"} metadata = ComponentMetadata(environment_variables=env_vars) + + # EXCLUDE AZURE + allowed_providers = get_allowed_providers() + random_test = random.choice(allowed_providers) + components = [ Component( name="test", component_flavor="zenml", component_type="mlops_platform", - provider=random.choice(list(ProviderEnum)).value, + provider=random_test, metadata=metadata, ) ] @@ -165,7 +176,9 @@ def test_component_var_parsing_works_for_env_vars(): def test_tf_variable_parsing_from_stack_works(): """Tests that the Terraform variables extraction (from a stack) works.""" - provider = random.choice(list(ProviderEnum)).value + allowed_providers = get_allowed_providers() + provider = random.choice(allowed_providers) + component_flavor = "zenml" metadata = ComponentMetadata() components = [ diff --git a/tests/unit/utils/test_zenml_utils.py b/tests/unit/utils/test_zenml_utils.py index 7958914f..cc72771c 100644 --- a/tests/unit/utils/test_zenml_utils.py +++ b/tests/unit/utils/test_zenml_utils.py @@ -12,6 +12,7 @@ # permissions and limitations under the License. """Tests for utilities for mlstacks-ZenML interaction.""" + from mlstacks.models.component import Component from mlstacks.models.stack import Stack from mlstacks.utils.zenml_utils import has_valid_flavor_combinations @@ -53,7 +54,7 @@ def test_flavor_combination_validator_fails_aws_gcp(): name="blupus-component", component_type="artifact_store", component_flavor="gcp", - provider=valid_stack.provider, + provider="gcp", ) assert not has_valid_flavor_combinations( stack=valid_stack, @@ -75,7 +76,7 @@ def test_flavor_combination_validator_fails_k3d_s3(): name="blupus-component", component_type="artifact_store", component_flavor="s3", - provider=valid_stack.provider, + provider="aws", ) assert not has_valid_flavor_combinations( stack=valid_stack,