From 67b984355f53e57d23ba39c4a525818ba98c2113 Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Fri, 3 Jan 2025 18:18:20 -0800 Subject: [PATCH] feat(signal schema): serialize base classes for custom types --- src/datachain/lib/signal_schema.py | 70 ++++++++- tests/unit/lib/test_signal_schema.py | 214 +++++++++++++++++++++++---- 2 files changed, 247 insertions(+), 37 deletions(-) diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 1436cc584..89b309029 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -86,7 +86,9 @@ def __init__(self, method: str, field): def create_feature_model( - name: str, fields: dict[str, Union[type, tuple[type, Any]]] + name: str, + fields: dict[str, Union[type, tuple[type, Any]]], + base: Optional[type] = None, ) -> type[BaseModel]: """ This gets or returns a dynamic feature model for use in restoring a model @@ -98,7 +100,7 @@ def create_feature_model( name = name.replace("@", "_") return create_model( name, - __base__=DataModel, # type: ignore[call-overload] + __base__=base or DataModel, # type: ignore[call-overload] # These are tuples for each field of: annotation, default (if any) **{ field_name: anno if isinstance(anno, tuple) else (anno, None) @@ -165,12 +167,25 @@ def _serialize_custom_model_fields( # This type is already stored in custom_types. return version_name fields = {} + for field_name, info in fr.model_fields.items(): field_type = info.annotation # All fields should be typed. assert field_type fields[field_name] = SignalSchema._serialize_type(field_type, custom_types) - custom_types[version_name] = fields + + bases: list[tuple[str, str, Optional[str]]] = [] + for type_ in fr.__mro__: + model_store_name = ( + ModelStore.get_name(type_) if issubclass(type_, DataModel) else None + ) + bases.append((type_.__name__, type_.__module__, model_store_name)) + + custom_types[version_name] = { + "_custom_types_schema_version": 2, + "fields": fields, + "bases": bases, + } return version_name @staticmethod @@ -184,7 +199,6 @@ def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str: if st is None or not ModelStore.is_pydantic(st): continue # Register and save feature types. - ModelStore.register(st) st_version_name = ModelStore.get_name(st) if st is fr: # If the main type is Pydantic, then use the ModelStore version name. @@ -228,7 +242,29 @@ def _split_subtypes(type_name: str) -> list[str]: return subtypes @staticmethod - def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911 + def _get_custom_type_attr( + attr_name: str, + type_name: str, + custom_types: dict[str, Any], + default: Any = None, + ) -> Any: + custom_type_description = custom_types.get(type_name) + + if custom_type_description is None: + raise SignalSchemaError( + f"cannot deserialize '{type_name}' from custom types" + ) + + # Backward compatibility with the old custom types schema version that didn't + # include bases, and only had fields. + if "_custom_types_schema_version" not in custom_type_description: + if attr_name == "fields": + return custom_type_description + return default + return custom_type_description[attr_name] + + @staticmethod + def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911, C901 """Convert a string-based type back into a python type.""" type_name = type_name.strip() if not type_name: @@ -273,12 +309,29 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type return fr if type_name in custom_types: - fields = custom_types[type_name] + fields = SignalSchema._get_custom_type_attr( + "fields", type_name, custom_types + ) + bases = SignalSchema._get_custom_type_attr( + "bases", type_name, custom_types, default=[] + ) fields = { field_name: SignalSchema._resolve_type(field_type_str, custom_types) for field_name, field_type_str in fields.items() } - return create_feature_model(type_name, fields) + + base_model = None + for base in bases: + _, _, model_store_name = base + if model_store_name: + model_name, version = ModelStore.parse_name_version( + model_store_name + ) + base_model = ModelStore.get(model_name, version) + if base_model: + break + + return create_feature_model(type_name, fields, base=base_model) # This can occur if a third-party or custom type is used, which is not available # when deserializing. warnings.warn( @@ -662,6 +715,9 @@ def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str: stacklevel=2, ) return "Any" + if ModelStore.is_pydantic(type_): + ModelStore.register(type_) + return ModelStore.get_name(type_) return type_.__name__ @staticmethod diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 7a33e11f1..6cebcc0c7 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -129,8 +129,22 @@ def test_feature_schema_serialize_optional(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature"] == "Union[MyType1, NoneType]" - assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + assert signals["feature"] == "Union[MyType1@v1, NoneType]" + assert signals["_custom_types"] == { + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + } + } + + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema def test_feature_schema_serialize_list(): @@ -142,8 +156,22 @@ def test_feature_schema_serialize_list(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["features"] == "list[MyType1]" - assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + assert signals["features"] == "list[MyType1@v1]" + assert signals["_custom_types"] == { + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + } + } + + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema def test_feature_schema_serialize_list_old(): @@ -155,8 +183,27 @@ def test_feature_schema_serialize_list_old(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["features"] == "list[MyType1]" - assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + assert signals["features"] == "list[MyType1@v1]" + assert signals["_custom_types"] == { + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + } + } + + new_schema = { + "name": Optional[str], + "features": list[MyType1], + } + + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == new_schema def test_feature_schema_serialize_nested_types(): @@ -168,12 +215,33 @@ def test_feature_schema_serialize_nested_types(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature_nested"] == "Union[MyType2, NoneType]" + assert signals["feature_nested"] == "Union[MyType2@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "_custom_types_schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, } + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema + def test_feature_schema_serialize_nested_duplicate_types(): schema = { @@ -185,13 +253,34 @@ def test_feature_schema_serialize_nested_duplicate_types(): assert len(signals) == 4 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature_nested"] == "Union[MyType2, NoneType]" - assert signals["feature_not_nested"] == "Union[MyType1, NoneType]" + assert signals["feature_nested"] == "Union[MyType2@v1, NoneType]" + assert signals["feature_not_nested"] == "Union[MyType1@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "_custom_types_schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, } + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema + def test_feature_schema_serialize_complex(): schema = { @@ -202,17 +291,51 @@ def test_feature_schema_serialize_complex(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature"] == "Union[MyTypeComplex, NoneType]" + assert signals["feature"] == "Union[MyTypeComplex@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "_custom_types_schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, "MyTypeComplex@v1": { - "name": "str", - "items": "list[MyType1]", - "lookup": "dict[str, MyType2]", + "_custom_types_schema_version": 2, + "fields": { + "name": "str", + "items": "list[MyType1@v1]", + "lookup": "dict[str, MyType2@v1]", + }, + "bases": [ + ( + "MyTypeComplex", + "tests.unit.lib.test_signal_schema", + "MyTypeComplex@v1", + ), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], }, } + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema + def test_feature_schema_serialize_complex_old(): schema = { @@ -223,14 +346,45 @@ def test_feature_schema_serialize_complex_old(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature"] == "Union[MyTypeComplexOld, NoneType]" + assert signals["feature"] == "Union[MyTypeComplexOld@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "_custom_types_schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, "MyTypeComplexOld@v1": { - "name": "str", - "items": "list[MyType1]", - "lookup": "dict[str, MyType2]", + "_custom_types_schema_version": 2, + "fields": { + "name": "str", + "items": "list[MyType1@v1]", + "lookup": "dict[str, MyType2@v1]", + }, + "bases": [ + ( + "MyTypeComplexOld", + "tests.unit.lib.test_signal_schema", + "MyTypeComplexOld@v1", + ), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], }, } @@ -502,14 +656,14 @@ def test_print_types(): int: "int", float: "float", None: "NoneType", - MyType2: "MyType2", + MyType2: "MyType2@v1", Any: "Any", Literal: "Literal", Final: "Final", - Optional[MyType2]: "Union[MyType2, NoneType]", + Optional[MyType2]: "Union[MyType2@v1, NoneType]", Union[str, int]: "Union[str, int]", Union[str, int, bool]: "Union[str, int, bool]", - Union[Optional[MyType2]]: "Union[MyType2, NoneType]", + Union[Optional[MyType2]]: "Union[MyType2@v1, NoneType]", list: "list", list[bool]: "list[bool]", List[bool]: "list[bool]", # noqa: UP006 @@ -518,8 +672,8 @@ def test_print_types(): dict: "dict", dict[str, bool]: "dict[str, bool]", Dict[str, bool]: "dict[str, bool]", # noqa: UP006 - dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]", - Dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]", # noqa: UP006 + dict[str, Optional[MyType1]]: "dict[str, Union[MyType1@v1, NoneType]]", + Dict[str, Optional[MyType1]]: "dict[str, Union[MyType1@v1, NoneType]]", # noqa: UP006 Union[str, list[str]]: "Union[str, list[str]]", Union[str, List[str]]: "Union[str, list[str]]", # noqa: UP006 Optional[Literal["x"]]: "Union[Literal, NoneType]",