From 7d188a638fbbda0ae109586176aad67092fb9060 Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Fri, 3 Jan 2025 18:18:20 -0800 Subject: [PATCH] feat(signal schema): serialize base classes for custom types --- src/datachain/lib/signal_schema.py | 116 +++++++--- tests/unit/lib/test_signal_schema.py | 312 ++++++++++++++++++++++++--- 2 files changed, 369 insertions(+), 59 deletions(-) diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 1436cc584..d723c5b8d 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -13,13 +13,14 @@ Final, List, Literal, + Mapping, Optional, Union, get_args, get_origin, ) -from pydantic import BaseModel, create_model +from pydantic import BaseModel, Field, create_model from sqlalchemy import ColumnElement from typing_extensions import Literal as LiteralEx @@ -85,8 +86,31 @@ def __init__(self, method: str, field): ) +class CustomType(BaseModel): + schema_version: int = Field(ge=1, le=2, strict=True) + name: str + fields: dict[str, str] + bases: list[tuple[str, str, Optional[str]]] + + @classmethod + def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType": + version = data.get("schema_version", 1) + + if version == 1: + data = { + "schema_version": 1, + "name": type_name, + "fields": data, + "bases": [], + } + + return cls(**data) + + def create_feature_model( - name: str, fields: dict[str, Union[type, tuple[type, Any]]] + name: str, + fields: Mapping[str, Union[type, None, tuple[type, Any]]], + base: Optional[type] = None, ) -> type[BaseModel]: """ This gets or returns a dynamic feature model for use in restoring a model @@ -98,7 +122,7 @@ def create_feature_model( name = name.replace("@", "_") return create_model( name, - __base__=DataModel, # type: ignore[call-overload] + __base__=base or DataModel, # type: ignore[call-overload] # These are tuples for each field of: annotation, default (if any) **{ field_name: anno if isinstance(anno, tuple) else (anno, None) @@ -156,7 +180,7 @@ def from_column_types(col_types: dict[str, Any]) -> "SignalSchema": return SignalSchema(signals) @staticmethod - def _serialize_custom_model_fields( + def _serialize_custom_model( version_name: str, fr: type[BaseModel], custom_types: dict[str, Any] ) -> str: """This serializes any custom type information to the provided custom_types @@ -165,12 +189,23 @@ def _serialize_custom_model_fields( # This type is already stored in custom_types. return version_name fields = {} + for field_name, info in fr.model_fields.items(): field_type = info.annotation # All fields should be typed. assert field_type fields[field_name] = SignalSchema._serialize_type(field_type, custom_types) - custom_types[version_name] = fields + + bases: list[tuple[str, str, Optional[str]]] = [] + for type_ in fr.__mro__: + model_store_name = ( + ModelStore.get_name(type_) if issubclass(type_, DataModel) else None + ) + bases.append((type_.__name__, type_.__module__, model_store_name)) + + ct = CustomType(schema_version=2, name=version_name, fields=fields, bases=bases) + custom_types[version_name] = ct.model_dump() + return version_name @staticmethod @@ -184,15 +219,12 @@ def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str: if st is None or not ModelStore.is_pydantic(st): continue # Register and save feature types. - ModelStore.register(st) st_version_name = ModelStore.get_name(st) if st is fr: # If the main type is Pydantic, then use the ModelStore version name. type_name = st_version_name # Save this type to custom_types. - SignalSchema._serialize_custom_model_fields( - st_version_name, st, custom_types - ) + SignalSchema._serialize_custom_model(st_version_name, st, custom_types) return type_name def serialize(self) -> dict[str, Any]: @@ -215,7 +247,7 @@ def _split_subtypes(type_name: str) -> list[str]: depth += 1 elif c == "]": if depth == 0: - raise TypeError( + raise ValueError( "Extra closing square bracket when parsing subtype list" ) depth -= 1 @@ -223,16 +255,51 @@ def _split_subtypes(type_name: str) -> list[str]: subtypes.append(type_name[start:i].strip()) start = i + 1 if depth > 0: - raise TypeError("Unclosed square bracket when parsing subtype list") + raise ValueError("Unclosed square bracket when parsing subtype list") subtypes.append(type_name[start:].strip()) return subtypes @staticmethod - def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911 + def _deserialize_custom_type( + type_name: str, custom_types: dict[str, Any] + ) -> Optional[type]: + """Given a type name like MyType@v1 gets a type from ModelStore or recreates + it based on the information from the custom types dict that includes fields and + bases.""" + model_name, version = ModelStore.parse_name_version(type_name) + fr = ModelStore.get(model_name, version) + if fr: + return fr + + if type_name in custom_types: + ct = CustomType.deserialize(custom_types[type_name], type_name) + + fields = { + field_name: SignalSchema._resolve_type(field_type_str, custom_types) + for field_name, field_type_str in ct.fields.items() + } + + base_model = None + for base in ct.bases: + _, _, model_store_name = base + if model_store_name: + model_name, version = ModelStore.parse_name_version( + model_store_name + ) + base_model = ModelStore.get(model_name, version) + if base_model: + break + + return create_feature_model(type_name, fields, base=base_model) + + return None + + @staticmethod + def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: """Convert a string-based type back into a python type.""" type_name = type_name.strip() if not type_name: - raise TypeError("Type cannot be empty") + raise ValueError("Type cannot be empty") if type_name == "NoneType": return None @@ -240,14 +307,14 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type subtypes: Optional[tuple[Optional[type], ...]] = None if bracket_idx > -1: if bracket_idx == 0: - raise TypeError("Type cannot start with '['") + raise ValueError("Type cannot start with '['") close_bracket_idx = type_name.rfind("]") if close_bracket_idx == -1: - raise TypeError("Unclosed square bracket when parsing type") + raise ValueError("Unclosed square bracket when parsing type") if close_bracket_idx < bracket_idx: - raise TypeError("Square brackets are out of order when parsing type") + raise ValueError("Square brackets are out of order when parsing type") if close_bracket_idx == bracket_idx + 1: - raise TypeError("Empty square brackets when parsing type") + raise ValueError("Empty square brackets when parsing type") subtype_names = SignalSchema._split_subtypes( type_name[bracket_idx + 1 : close_bracket_idx] ) @@ -267,18 +334,10 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type return fr[subtypes] # type: ignore[index] return fr # type: ignore[return-value] - model_name, version = ModelStore.parse_name_version(type_name) - fr = ModelStore.get(model_name, version) + fr = SignalSchema._deserialize_custom_type(type_name, custom_types) if fr: return fr - if type_name in custom_types: - fields = custom_types[type_name] - fields = { - field_name: SignalSchema._resolve_type(field_type_str, custom_types) - for field_name, field_type_str in fields.items() - } - return create_feature_model(type_name, fields) # This can occur if a third-party or custom type is used, which is not available # when deserializing. warnings.warn( @@ -317,7 +376,7 @@ def deserialize(schema: dict[str, Any]) -> "SignalSchema": stacklevel=2, ) continue - except TypeError as err: + except ValueError as err: raise SignalSchemaError( f"cannot deserialize '{signal}': {err}" ) from err @@ -662,6 +721,9 @@ def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str: stacklevel=2, ) return "Any" + if ModelStore.is_pydantic(type_): + ModelStore.register(type_) + return ModelStore.get_name(type_) return type_.__name__ @staticmethod diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 7a33e11f1..cef421b5d 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -7,6 +7,7 @@ from datachain import Column, DataModel from datachain.lib.convert.flatten import flatten from datachain.lib.file import File, TextFile +from datachain.lib.model_store import ModelStore from datachain.lib.signal_schema import ( SetupError, SignalResolvingError, @@ -53,6 +54,10 @@ class MyType2(DataModel): deep: MyType1 +class MyType3(MyType1): + name: str + + class MyTypeComplex(DataModel): name: str items: list[MyType1] @@ -129,8 +134,23 @@ def test_feature_schema_serialize_optional(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature"] == "Union[MyType1, NoneType]" - assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + assert signals["feature"] == "Union[MyType1@v1, NoneType]" + assert signals["_custom_types"] == { + "MyType1@v1": { + "schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "name": "MyType1@v1", + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + } + } + + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema def test_feature_schema_serialize_list(): @@ -142,8 +162,23 @@ def test_feature_schema_serialize_list(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["features"] == "list[MyType1]" - assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + assert signals["features"] == "list[MyType1@v1]" + assert signals["_custom_types"] == { + "MyType1@v1": { + "schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "name": "MyType1@v1", + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + } + } + + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema def test_feature_schema_serialize_list_old(): @@ -155,8 +190,28 @@ def test_feature_schema_serialize_list_old(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["features"] == "list[MyType1]" - assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + assert signals["features"] == "list[MyType1@v1]" + assert signals["_custom_types"] == { + "MyType1@v1": { + "schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "name": "MyType1@v1", + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + } + } + + new_schema = { + "name": Optional[str], + "features": list[MyType1], + } + + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == new_schema def test_feature_schema_serialize_nested_types(): @@ -168,12 +223,35 @@ def test_feature_schema_serialize_nested_types(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature_nested"] == "Union[MyType2, NoneType]" + assert signals["feature_nested"] == "Union[MyType2@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "name": "MyType1@v1", + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "name": "MyType2@v1", + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, } + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema + def test_feature_schema_serialize_nested_duplicate_types(): schema = { @@ -185,13 +263,36 @@ def test_feature_schema_serialize_nested_duplicate_types(): assert len(signals) == 4 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature_nested"] == "Union[MyType2, NoneType]" - assert signals["feature_not_nested"] == "Union[MyType1, NoneType]" + assert signals["feature_nested"] == "Union[MyType2@v1, NoneType]" + assert signals["feature_not_nested"] == "Union[MyType1@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "name": "MyType1@v1", + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "name": "MyType2@v1", + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, } + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema + def test_feature_schema_serialize_complex(): schema = { @@ -202,17 +303,54 @@ def test_feature_schema_serialize_complex(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature"] == "Union[MyTypeComplex, NoneType]" + assert signals["feature"] == "Union[MyTypeComplex@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "name": "MyType1@v1", + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "name": "MyType2@v1", + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, "MyTypeComplex@v1": { - "name": "str", - "items": "list[MyType1]", - "lookup": "dict[str, MyType2]", + "schema_version": 2, + "fields": { + "name": "str", + "items": "list[MyType1@v1]", + "lookup": "dict[str, MyType2@v1]", + }, + "name": "MyTypeComplex@v1", + "bases": [ + ( + "MyTypeComplex", + "tests.unit.lib.test_signal_schema", + "MyTypeComplex@v1", + ), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], }, } + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema + def test_feature_schema_serialize_complex_old(): schema = { @@ -223,14 +361,48 @@ def test_feature_schema_serialize_complex_old(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature"] == "Union[MyTypeComplexOld, NoneType]" + assert signals["feature"] == "Union[MyTypeComplexOld@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "name": "MyType1@v1", + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "name": "MyType2@v1", + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, "MyTypeComplexOld@v1": { - "name": "str", - "items": "list[MyType1]", - "lookup": "dict[str, MyType2]", + "schema_version": 2, + "fields": { + "name": "str", + "items": "list[MyType1@v1]", + "lookup": "dict[str, MyType2@v1]", + }, + "name": "MyTypeComplexOld@v1", + "bases": [ + ( + "MyTypeComplexOld", + "tests.unit.lib.test_signal_schema", + "MyTypeComplexOld@v1", + ), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], }, } @@ -294,12 +466,14 @@ def test_select(): assert signals["f.bb"] is str -def test_select_custom_type(): +def test_select_custom_type_backward_compatibility(): schema = SignalSchema.deserialize( { "age": "float", "address": "str", "f": "ExternalCustomType1@v1", + # Older custom types schema is supported + # Can be removed a bit later "_custom_types": {"ExternalCustomType1@v1": {"aa": "int", "bb": "str"}}, } ) @@ -315,6 +489,80 @@ def test_select_custom_type(): assert signals["f.bb"] is str +def test_select_custom_type(): + schema = SignalSchema.deserialize( + { + "age": "float", + "address": "str", + "f": "ExternalCustomType1@v1", + "_custom_types": { + "ExternalCustomType1@v1": { + "schema_version": 2, + "name": "ExternalCustomType1@v1", + "fields": {"aa": "int", "bb": "str"}, + "bases": [], + }, + }, + } + ) + + new = schema.resolve("age", "f.aa", "f.bb") + assert isinstance(new, SignalSchema) + + signals = new.values + assert len(signals) == 3 + assert {"age", "f.aa", "f.bb"} == signals.keys() + assert signals["age"] is float + assert signals["f.aa"] is int + assert signals["f.bb"] is str + + +def test_deserialize_restores_known_base_type(): + schema = {"fr": MyType3} + signals = SignalSchema(schema).serialize() + ModelStore.remove(MyType3) + + # Seince MyType3 is removed, deserialization restores it + # from the meta information stored in the schema, including the base type + # that is still known - MyType1 + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values["fr"].__name__ == "MyType3_v1" + assert issubclass(deserialized_schema.values["fr"], MyType1) + + +def test_deserialize_custom_type_bad_schema(): + # No `bases` field + with pytest.raises(SignalSchemaError): + SignalSchema.deserialize( + { + "f": "ExternalCustomType1@v1", + "_custom_types": { + "ExternalCustomType1@v1": { + "schema_version": 2, + "name": "ExternalCustomType1@v1", + "fields": {"aa": "int", "bb": "str"}, + }, + }, + } + ) + + # Bad version + with pytest.raises(SignalSchemaError): + SignalSchema.deserialize( + { + "f": "ExternalCustomType1@v1", + "_custom_types": { + "ExternalCustomType1@v1": { + "schema_version": 123, + "name": "ExternalCustomType1@v1", + "fields": {"aa": "int", "bb": "str"}, + "bases": [], + }, + }, + } + ) + + def test_select_nested_names(): schema = SignalSchema.deserialize( { @@ -502,14 +750,14 @@ def test_print_types(): int: "int", float: "float", None: "NoneType", - MyType2: "MyType2", + MyType2: "MyType2@v1", Any: "Any", Literal: "Literal", Final: "Final", - Optional[MyType2]: "Union[MyType2, NoneType]", + Optional[MyType2]: "Union[MyType2@v1, NoneType]", Union[str, int]: "Union[str, int]", Union[str, int, bool]: "Union[str, int, bool]", - Union[Optional[MyType2]]: "Union[MyType2, NoneType]", + Union[Optional[MyType2]]: "Union[MyType2@v1, NoneType]", list: "list", list[bool]: "list[bool]", List[bool]: "list[bool]", # noqa: UP006 @@ -518,8 +766,8 @@ def test_print_types(): dict: "dict", dict[str, bool]: "dict[str, bool]", Dict[str, bool]: "dict[str, bool]", # noqa: UP006 - dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]", - Dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]", # noqa: UP006 + dict[str, Optional[MyType1]]: "dict[str, Union[MyType1@v1, NoneType]]", + Dict[str, Optional[MyType1]]: "dict[str, Union[MyType1@v1, NoneType]]", # noqa: UP006 Union[str, list[str]]: "Union[str, list[str]]", Union[str, List[str]]: "Union[str, list[str]]", # noqa: UP006 Optional[Literal["x"]]: "Union[Literal, NoneType]", @@ -604,7 +852,7 @@ def test_resolve_types_errors(): } for t, m in bogus_types_messages.items(): - with pytest.raises(TypeError, match=m): + with pytest.raises(ValueError, match=m): SignalSchema._resolve_type(t, {})