feat(signal schema): serialize base classes for custom types
shcheklein committed Jan 4, 2025
1 parent 8dfa4ff commit 3e7173a
Showing 3 changed files with 232 additions and 33 deletions.
38 changes: 35 additions & 3 deletions src/datachain/lib/signal_schema.py
@@ -165,12 +165,25 @@ def _serialize_custom_model_fields(
# This type is already stored in custom_types.
return version_name
fields = {}

for field_name, info in fr.model_fields.items():
field_type = info.annotation
# All fields should be typed.
assert field_type
fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
custom_types[version_name] = fields

bases: list[tuple[str, str, Optional[str]]] = []
for type_ in fr.__mro__:
model_store_name = (
ModelStore.get_name(type_) if issubclass(type_, DataModel) else None
)
bases.append((type_.__name__, type_.__module__, model_store_name))

custom_types[version_name] = {
"_custom_types_schema_version": 2,
"fields": fields,
"bases": bases,
}
return version_name

@staticmethod
@@ -184,7 +197,6 @@ def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str:
if st is None or not ModelStore.is_pydantic(st):
continue
# Register and save feature types.
ModelStore.register(st)
st_version_name = ModelStore.get_name(st)
if st is fr:
# If the main type is Pydantic, then use the ModelStore version name.
@@ -227,6 +239,23 @@ def _split_subtypes(type_name: str) -> list[str]:
subtypes.append(type_name[start:].strip())
return subtypes

@staticmethod
def _get_custom_type_fields(
type_name: str, custom_types: dict[str, Any]
) -> dict[str, Any]:
custom_type_description = custom_types.get(type_name)

if custom_type_description is None:
raise SignalSchemaError(
f"cannot deserialize '{type_name}' from custom types"
)

# Backward compatibility with the old custom types schema version that didn't
# include bases, and only had fields.
if "_custom_types_schema_version" not in custom_type_description:
return custom_type_description
return custom_type_description["fields"]

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911
"""Convert a string-based type back into a python type."""
@@ -273,7 +302,7 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type
return fr

if type_name in custom_types:
fields = custom_types[type_name]
fields = SignalSchema._get_custom_type_fields(type_name, custom_types)
fields = {
field_name: SignalSchema._resolve_type(field_type_str, custom_types)
for field_name, field_type_str in fields.items()
@@ -662,6 +691,9 @@ def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str:
stacklevel=2,
)
return "Any"
if ModelStore.is_pydantic(type_):
ModelStore.register(type_)
return ModelStore.get_name(type_)
return type_.__name__

@staticmethod
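
For context, here is a minimal sketch (not part of the commit) of the round trip this change enables. MyType below is a hypothetical DataModel subclass; the serialize()/deserialize() calls and the expected "_custom_types" layout mirror what the updated tests assert, and the module name in the first base tuple depends on where MyType is defined.

from typing import Optional

from datachain.lib.data_model import DataModel
from datachain.lib.signal_schema import SignalSchema


class MyType(DataModel):  # hypothetical example type, not taken from the diff
    aa: int
    bb: str


# Serializing a schema that references the custom type now also records
# its base classes as (class name, module, ModelStore name or None) tuples.
signals = SignalSchema({"name": Optional[str], "feature": MyType}).serialize()

# Expected shape of the version-2 custom-type entry:
# signals["_custom_types"]["MyType@v1"] == {
#     "_custom_types_schema_version": 2,
#     "fields": {"aa": "int", "bb": "str"},
#     "bases": [
#         ("MyType", "<defining module>", "MyType@v1"),
#         ("DataModel", "datachain.lib.data_model", "DataModel@v1"),
#         ("BaseModel", "pydantic.main", None),
#         ("object", "builtins", None),
#     ],
# }

# Deserialization accepts both this layout and legacy fields-only entries
# (see _get_custom_type_fields) and restores equivalent Python types.
restored = SignalSchema.deserialize(signals)
assert restored.values == {"name": Optional[str], "feature": MyType}
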
13 changes: 13 additions & 0 deletions tests/func/test_datachain.py
@@ -82,6 +82,19 @@ def test_from_storage_glob(cloud_test_catalog):
assert dc.count() == 4


def test_from_storage_simulated_directory(cloud_test_catalog):
ctc = cloud_test_catalog
catalog = ctc.catalog
src_uri = ctc.src_uri
client = catalog.get_client(src_uri)
fs = client.fs
fs.touch(f"{ctc.src_uri}/simulated-dir")
fs.mv(f"{ctc.src_uri}/simulated-dir", f"{ctc.src_uri}/simulated-dir/")
dc = DataChain.from_storage(f"{ctc.src_uri}/simulated-dir", session=ctc.session)
print(dc.show(3))
assert dc.count() == 0


def test_from_storage_as_image(cloud_test_catalog):
ctc = cloud_test_catalog
dc = DataChain.from_storage(ctc.src_uri, session=ctc.session, type="image")
214 changes: 184 additions & 30 deletions tests/unit/lib/test_signal_schema.py
@@ -129,8 +129,22 @@ def test_feature_schema_serialize_optional():

assert len(signals) == 3
assert signals["name"] == "Union[str, NoneType]"
assert signals["feature"] == "Union[MyType1, NoneType]"
assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}}
assert signals["feature"] == "Union[MyType1@v1, NoneType]"
assert signals["_custom_types"] == {
"MyType1@v1": {
"_custom_types_schema_version": 2,
"fields": {"aa": "int", "bb": "str"},
"bases": [
("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
}
}

deserialized_schema = SignalSchema.deserialize(signals)
assert deserialized_schema.values == schema


def test_feature_schema_serialize_list():
@@ -142,8 +156,22 @@ def test_feature_schema_serialize_list():

assert len(signals) == 3
assert signals["name"] == "Union[str, NoneType]"
assert signals["features"] == "list[MyType1]"
assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}}
assert signals["features"] == "list[MyType1@v1]"
assert signals["_custom_types"] == {
"MyType1@v1": {
"_custom_types_schema_version": 2,
"fields": {"aa": "int", "bb": "str"},
"bases": [
("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
}
}

deserialized_schema = SignalSchema.deserialize(signals)
assert deserialized_schema.values == schema


def test_feature_schema_serialize_list_old():
@@ -155,8 +183,27 @@ def test_feature_schema_serialize_list_old():

assert len(signals) == 3
assert signals["name"] == "Union[str, NoneType]"
assert signals["features"] == "list[MyType1]"
assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}}
assert signals["features"] == "list[MyType1@v1]"
assert signals["_custom_types"] == {
"MyType1@v1": {
"_custom_types_schema_version": 2,
"fields": {"aa": "int", "bb": "str"},
"bases": [
("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
}
}

new_schema = {
"name": Optional[str],
"features": list[MyType1],
}

deserialized_schema = SignalSchema.deserialize(signals)
assert deserialized_schema.values == new_schema


def test_feature_schema_serialize_nested_types():
@@ -168,12 +215,33 @@ def test_feature_schema_serialize_nested_types():

assert len(signals) == 3
assert signals["name"] == "Union[str, NoneType]"
assert signals["feature_nested"] == "Union[MyType2, NoneType]"
assert signals["feature_nested"] == "Union[MyType2@v1, NoneType]"
assert signals["_custom_types"] == {
"MyType1@v1": {"aa": "int", "bb": "str"},
"MyType2@v1": {"deep": "MyType1@v1", "name": "str"},
"MyType1@v1": {
"_custom_types_schema_version": 2,
"fields": {"aa": "int", "bb": "str"},
"bases": [
("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
"MyType2@v1": {
"_custom_types_schema_version": 2,
"fields": {"name": "str", "deep": "MyType1@v1"},
"bases": [
("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
}

deserialized_schema = SignalSchema.deserialize(signals)
assert deserialized_schema.values == schema


def test_feature_schema_serialize_nested_duplicate_types():
schema = {
@@ -185,13 +253,34 @@ def test_feature_schema_serialize_nested_duplicate_types():

assert len(signals) == 4
assert signals["name"] == "Union[str, NoneType]"
assert signals["feature_nested"] == "Union[MyType2, NoneType]"
assert signals["feature_not_nested"] == "Union[MyType1, NoneType]"
assert signals["feature_nested"] == "Union[MyType2@v1, NoneType]"
assert signals["feature_not_nested"] == "Union[MyType1@v1, NoneType]"
assert signals["_custom_types"] == {
"MyType1@v1": {"aa": "int", "bb": "str"},
"MyType2@v1": {"deep": "MyType1@v1", "name": "str"},
"MyType1@v1": {
"_custom_types_schema_version": 2,
"fields": {"aa": "int", "bb": "str"},
"bases": [
("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
"MyType2@v1": {
"_custom_types_schema_version": 2,
"fields": {"name": "str", "deep": "MyType1@v1"},
"bases": [
("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
}

deserialized_schema = SignalSchema.deserialize(signals)
assert deserialized_schema.values == schema


def test_feature_schema_serialize_complex():
schema = {
@@ -202,17 +291,51 @@ def test_feature_schema_serialize_complex():

assert len(signals) == 3
assert signals["name"] == "Union[str, NoneType]"
assert signals["feature"] == "Union[MyTypeComplex, NoneType]"
assert signals["feature"] == "Union[MyTypeComplex@v1, NoneType]"
assert signals["_custom_types"] == {
"MyType1@v1": {"aa": "int", "bb": "str"},
"MyType2@v1": {"deep": "MyType1@v1", "name": "str"},
"MyType1@v1": {
"_custom_types_schema_version": 2,
"fields": {"aa": "int", "bb": "str"},
"bases": [
("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
"MyType2@v1": {
"_custom_types_schema_version": 2,
"fields": {"name": "str", "deep": "MyType1@v1"},
"bases": [
("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
"MyTypeComplex@v1": {
"name": "str",
"items": "list[MyType1]",
"lookup": "dict[str, MyType2]",
"_custom_types_schema_version": 2,
"fields": {
"name": "str",
"items": "list[MyType1@v1]",
"lookup": "dict[str, MyType2@v1]",
},
"bases": [
(
"MyTypeComplex",
"tests.unit.lib.test_signal_schema",
"MyTypeComplex@v1",
),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
}

deserialized_schema = SignalSchema.deserialize(signals)
assert deserialized_schema.values == schema


def test_feature_schema_serialize_complex_old():
schema = {
@@ -223,14 +346,45 @@ def test_feature_schema_serialize_complex_old():

assert len(signals) == 3
assert signals["name"] == "Union[str, NoneType]"
assert signals["feature"] == "Union[MyTypeComplexOld, NoneType]"
assert signals["feature"] == "Union[MyTypeComplexOld@v1, NoneType]"
assert signals["_custom_types"] == {
"MyType1@v1": {"aa": "int", "bb": "str"},
"MyType2@v1": {"deep": "MyType1@v1", "name": "str"},
"MyType1@v1": {
"_custom_types_schema_version": 2,
"fields": {"aa": "int", "bb": "str"},
"bases": [
("MyType1", "tests.unit.lib.test_signal_schema", "MyType1@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
"MyType2@v1": {
"_custom_types_schema_version": 2,
"fields": {"name": "str", "deep": "MyType1@v1"},
"bases": [
("MyType2", "tests.unit.lib.test_signal_schema", "MyType2@v1"),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
"MyTypeComplexOld@v1": {
"name": "str",
"items": "list[MyType1]",
"lookup": "dict[str, MyType2]",
"_custom_types_schema_version": 2,
"fields": {
"name": "str",
"items": "list[MyType1@v1]",
"lookup": "dict[str, MyType2@v1]",
},
"bases": [
(
"MyTypeComplexOld",
"tests.unit.lib.test_signal_schema",
"MyTypeComplexOld@v1",
),
("DataModel", "datachain.lib.data_model", "DataModel@v1"),
("BaseModel", "pydantic.main", None),
("object", "builtins", None),
],
},
}

@@ -502,14 +656,14 @@ def test_print_types():
int: "int",
float: "float",
None: "NoneType",
MyType2: "MyType2",
MyType2: "MyType2@v1",
Any: "Any",
Literal: "Literal",
Final: "Final",
Optional[MyType2]: "Union[MyType2, NoneType]",
Optional[MyType2]: "Union[MyType2@v1, NoneType]",
Union[str, int]: "Union[str, int]",
Union[str, int, bool]: "Union[str, int, bool]",
Union[Optional[MyType2]]: "Union[MyType2, NoneType]",
Union[Optional[MyType2]]: "Union[MyType2@v1, NoneType]",
list: "list",
list[bool]: "list[bool]",
List[bool]: "list[bool]", # noqa: UP006
@@ -518,8 +672,8 @@
dict: "dict",
dict[str, bool]: "dict[str, bool]",
Dict[str, bool]: "dict[str, bool]", # noqa: UP006
dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]",
Dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]", # noqa: UP006
dict[str, Optional[MyType1]]: "dict[str, Union[MyType1@v1, NoneType]]",
Dict[str, Optional[MyType1]]: "dict[str, Union[MyType1@v1, NoneType]]", # noqa: UP006
Union[str, list[str]]: "Union[str, list[str]]",
Union[str, List[str]]: "Union[str, list[str]]", # noqa: UP006
Optional[Literal["x"]]: "Union[Literal, NoneType]",
