From 3e7173a003d44c007b22db30f7452b4544067af3 Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Fri, 3 Jan 2025 18:18:20 -0800 Subject: [PATCH] feat(signal schema): serialize base classes for custom types --- src/datachain/lib/signal_schema.py | 38 ++++- tests/func/test_datachain.py | 13 ++ tests/unit/lib/test_signal_schema.py | 214 +++++++++++++++++++++++---- 3 files changed, 232 insertions(+), 33 deletions(-) diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 1436cc584..44a5cee62 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -165,12 +165,25 @@ def _serialize_custom_model_fields( # This type is already stored in custom_types. return version_name fields = {} + for field_name, info in fr.model_fields.items(): field_type = info.annotation # All fields should be typed. assert field_type fields[field_name] = SignalSchema._serialize_type(field_type, custom_types) - custom_types[version_name] = fields + + bases: list[tuple[str, str, Optional[str]]] = [] + for type_ in fr.__mro__: + model_store_name = ( + ModelStore.get_name(type_) if issubclass(type_, DataModel) else None + ) + bases.append((type_.__name__, type_.__module__, model_store_name)) + + custom_types[version_name] = { + "_custom_types_schema_version": 2, + "fields": fields, + "bases": bases, + } return version_name @staticmethod @@ -184,7 +197,6 @@ def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str: if st is None or not ModelStore.is_pydantic(st): continue # Register and save feature types. - ModelStore.register(st) st_version_name = ModelStore.get_name(st) if st is fr: # If the main type is Pydantic, then use the ModelStore version name. @@ -227,6 +239,23 @@ def _split_subtypes(type_name: str) -> list[str]: subtypes.append(type_name[start:].strip()) return subtypes + @staticmethod + def _get_custom_type_fields( + type_name: str, custom_types: dict[str, Any] + ) -> dict[str, Any]: + custom_type_description = custom_types.get(type_name) + + if custom_type_description is None: + raise SignalSchemaError( + f"cannot deserialize '{type_name}' from custom types" + ) + + # Backward compatibility with the old custom types schema version that didn't + # include bases, and only had fields. + if "_custom_types_schema_version" not in custom_type_description: + return custom_type_description + return custom_type_description["fields"] + @staticmethod def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911 """Convert a string-based type back into a python type.""" @@ -273,7 +302,7 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type return fr if type_name in custom_types: - fields = custom_types[type_name] + fields = SignalSchema._get_custom_type_fields(type_name, custom_types) fields = { field_name: SignalSchema._resolve_type(field_type_str, custom_types) for field_name, field_type_str in fields.items() @@ -662,6 +691,9 @@ def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str: stacklevel=2, ) return "Any" + if ModelStore.is_pydantic(type_): + ModelStore.register(type_) + return ModelStore.get_name(type_) return type_.__name__ @staticmethod diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index f4f504a4d..9e290c606 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -82,6 +82,19 @@ def test_from_storage_glob(cloud_test_catalog): assert dc.count() == 4 +def test_from_storage_simulated_directory(cloud_test_catalog): + ctc = cloud_test_catalog + catalog = ctc.catalog + src_uri = ctc.src_uri + client = catalog.get_client(src_uri) + fs = client.fs + fs.touch(f"{ctc.src_uri}/simulated-dir") + fs.mv(f"{ctc.src_uri}/simulated-dir", f"{ctc.src_uri}/simulated-dir/") + dc = DataChain.from_storage(f"{ctc.src_uri}/simulated-dir", session=ctc.session) + print(dc.show(3)) + assert dc.count() == 0 + + def test_from_storage_as_image(cloud_test_catalog): ctc = cloud_test_catalog dc = DataChain.from_storage(ctc.src_uri, session=ctc.session, type="image") diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 7a33e11f1..6cebcc0c7 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -129,8 +129,22 @@ def test_feature_schema_serialize_optional(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature"] == "Union[MyType1, NoneType]" - assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + assert signals["feature"] == "Union[MyType1@v1, NoneType]" + assert signals["_custom_types"] == { + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + } + } + + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema def test_feature_schema_serialize_list(): @@ -142,8 +156,22 @@ def test_feature_schema_serialize_list(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["features"] == "list[MyType1]" - assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + assert signals["features"] == "list[MyType1@v1]" + assert signals["_custom_types"] == { + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + } + } + + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema def test_feature_schema_serialize_list_old(): @@ -155,8 +183,27 @@ def test_feature_schema_serialize_list_old(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["features"] == "list[MyType1]" - assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + assert signals["features"] == "list[MyType1@v1]" + assert signals["_custom_types"] == { + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + } + } + + new_schema = { + "name": Optional[str], + "features": list[MyType1], + } + + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == new_schema def test_feature_schema_serialize_nested_types(): @@ -168,12 +215,33 @@ def test_feature_schema_serialize_nested_types(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature_nested"] == "Union[MyType2, NoneType]" + assert signals["feature_nested"] == "Union[MyType2@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "_custom_types_schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, } + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema + def test_feature_schema_serialize_nested_duplicate_types(): schema = { @@ -185,13 +253,34 @@ def test_feature_schema_serialize_nested_duplicate_types(): assert len(signals) == 4 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature_nested"] == "Union[MyType2, NoneType]" - assert signals["feature_not_nested"] == "Union[MyType1, NoneType]" + assert signals["feature_nested"] == "Union[MyType2@v1, NoneType]" + assert signals["feature_not_nested"] == "Union[MyType1@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "_custom_types_schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, } + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema + def test_feature_schema_serialize_complex(): schema = { @@ -202,17 +291,51 @@ def test_feature_schema_serialize_complex(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature"] == "Union[MyTypeComplex, NoneType]" + assert signals["feature"] == "Union[MyTypeComplex@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "_custom_types_schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, "MyTypeComplex@v1": { - "name": "str", - "items": "list[MyType1]", - "lookup": "dict[str, MyType2]", + "_custom_types_schema_version": 2, + "fields": { + "name": "str", + "items": "list[MyType1@v1]", + "lookup": "dict[str, MyType2@v1]", + }, + "bases": [ + ( + "MyTypeComplex", + "tests.unit.lib.test_signal_schema", + "MyTypeComplex@v1", + ), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], }, } + deserialized_schema = SignalSchema.deserialize(signals) + assert deserialized_schema.values == schema + def test_feature_schema_serialize_complex_old(): schema = { @@ -223,14 +346,45 @@ def test_feature_schema_serialize_complex_old(): assert len(signals) == 3 assert signals["name"] == "Union[str, NoneType]" - assert signals["feature"] == "Union[MyTypeComplexOld, NoneType]" + assert signals["feature"] == "Union[MyTypeComplexOld@v1, NoneType]" assert signals["_custom_types"] == { - "MyType1@v1": {"aa": "int", "bb": "str"}, - "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyType1@v1": { + "_custom_types_schema_version": 2, + "fields": {"aa": "int", "bb": "str"}, + "bases": [ + ("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, + "MyType2@v1": { + "_custom_types_schema_version": 2, + "fields": {"name": "str", "deep": "MyType1@v1"}, + "bases": [ + ("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], + }, "MyTypeComplexOld@v1": { - "name": "str", - "items": "list[MyType1]", - "lookup": "dict[str, MyType2]", + "_custom_types_schema_version": 2, + "fields": { + "name": "str", + "items": "list[MyType1@v1]", + "lookup": "dict[str, MyType2@v1]", + }, + "bases": [ + ( + "MyTypeComplexOld", + "tests.unit.lib.test_signal_schema", + "MyTypeComplexOld@v1", + ), + ("DataModel", "datachain.lib.data_model", "DataModel@v1"), + ("BaseModel", "pydantic.main", None), + ("object", "builtins", None), + ], }, } @@ -502,14 +656,14 @@ def test_print_types(): int: "int", float: "float", None: "NoneType", - MyType2: "MyType2", + MyType2: "MyType2@v1", Any: "Any", Literal: "Literal", Final: "Final", - Optional[MyType2]: "Union[MyType2, NoneType]", + Optional[MyType2]: "Union[MyType2@v1, NoneType]", Union[str, int]: "Union[str, int]", Union[str, int, bool]: "Union[str, int, bool]", - Union[Optional[MyType2]]: "Union[MyType2, NoneType]", + Union[Optional[MyType2]]: "Union[MyType2@v1, NoneType]", list: "list", list[bool]: "list[bool]", List[bool]: "list[bool]", # noqa: UP006 @@ -518,8 +672,8 @@ def test_print_types(): dict: "dict", dict[str, bool]: "dict[str, bool]", Dict[str, bool]: "dict[str, bool]", # noqa: UP006 - dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]", - Dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]", # noqa: UP006 + dict[str, Optional[MyType1]]: "dict[str, Union[MyType1@v1, NoneType]]", + Dict[str, Optional[MyType1]]: "dict[str, Union[MyType1@v1, NoneType]]", # noqa: UP006 Union[str, list[str]]: "Union[str, list[str]]", Union[str, List[str]]: "Union[str, list[str]]", # noqa: UP006 Optional[Literal["x"]]: "Union[Literal, NoneType]",