Skip to content

Commit

Permalink
feat(signal schema): serialize base classes for custom types
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Jan 4, 2025
1 parent 8dfa4ff commit a7d5da1
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 36 deletions.
67 changes: 61 additions & 6 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def __init__(self, method: str, field):


def create_feature_model(
name: str, fields: dict[str, Union[type, tuple[type, Any]]]
name: str,
fields: dict[str, Union[type, tuple[type, Any]]],
base: Optional[type] = None,
) -> type[BaseModel]:
"""
This gets or returns a dynamic feature model for use in restoring a model
Expand All @@ -98,7 +100,7 @@ def create_feature_model(
name = name.replace("@", "_")
return create_model(
name,
__base__=DataModel, # type: ignore[call-overload]
__base__=base or DataModel, # type: ignore[call-overload]
# These are tuples for each field of: annotation, default (if any)
**{
field_name: anno if isinstance(anno, tuple) else (anno, None)
Expand Down Expand Up @@ -165,12 +167,25 @@ def _serialize_custom_model_fields(
# This type is already stored in custom_types.
return version_name
fields = {}

for field_name, info in fr.model_fields.items():
field_type = info.annotation
# All fields should be typed.
assert field_type
fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
custom_types[version_name] = fields

bases: list[tuple[str, str, Optional[str]]] = []
for type_ in fr.__mro__:
model_store_name = (
ModelStore.get_name(type_) if issubclass(type_, DataModel) else None
)
bases.append((type_.__name__, type_.__module__, model_store_name))

custom_types[version_name] = {
"_custom_types_schema_version": 2,
"fields": fields,
"bases": bases,
}
return version_name

@staticmethod
Expand All @@ -184,7 +199,6 @@ def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str:
if st is None or not ModelStore.is_pydantic(st):
continue
# Register and save feature types.
ModelStore.register(st)
st_version_name = ModelStore.get_name(st)
if st is fr:
# If the main type is Pydantic, then use the ModelStore version name.
Expand Down Expand Up @@ -227,6 +241,28 @@ def _split_subtypes(type_name: str) -> list[str]:
subtypes.append(type_name[start:].strip())
return subtypes

@staticmethod
def _get_custom_type_attr(
    attr_name: str,
    type_name: str,
    custom_types: dict[str, Any],
    default: Any = None,
) -> Any:
    """
    Look up one attribute of a serialized custom type description.

    Two layouts are handled:

    - versioned: a dict carrying a "_custom_types_schema_version" key, with
      the attributes (e.g. "fields", "bases") stored as its entries;
    - legacy: the description is itself the mapping of fields and has no
      version key.

    Args:
        attr_name: Attribute to read from the description, e.g. "fields".
        type_name: Key of the custom type inside `custom_types`.
        custom_types: Mapping of type name to its serialized description.
        default: Value returned when the description does not provide
            `attr_name`.

    Returns:
        The stored attribute; the whole description when "fields" is requested
        from a legacy entry; `default` otherwise.

    Raises:
        SignalSchemaError: If `type_name` has no entry in `custom_types`.
    """
    custom_type_description = custom_types.get(type_name)

    if custom_type_description is None:
        # NOTE(review): Codecov reported this raise as not covered by tests.
        raise SignalSchemaError(
            f"cannot deserialize '{type_name}' from custom types"
        )

    # Backward compatibility with the old custom types schema version that didn't
    # include bases, and only had fields.
    if "_custom_types_schema_version" not in custom_type_description:
        if attr_name == "fields":
            return custom_type_description
        return default

    # Honour `default` for versioned descriptions as well, instead of raising a
    # bare KeyError when the requested attribute is absent.
    return custom_type_description.get(attr_name, default)

Check warning on line 264 in src/datachain/lib/signal_schema.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/signal_schema.py#L264

Added line #L264 was not covered by tests

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911
"""Convert a string-based type back into a python type."""
Expand Down Expand Up @@ -273,12 +309,28 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type
return fr

if type_name in custom_types:
fields = custom_types[type_name]
fields = SignalSchema._get_custom_type_attr(
"fields", type_name, custom_types
)
bases = SignalSchema._get_custom_type_attr(
"bases", type_name, custom_types, default=[]
)
fields = {
field_name: SignalSchema._resolve_type(field_type_str, custom_types)
for field_name, field_type_str in fields.items()
}
return create_feature_model(type_name, fields)

base_model = None
for base in bases:
_, _, base_model_store_name = base
model_name, version = ModelStore.parse_name_version(

Check warning on line 326 in src/datachain/lib/signal_schema.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/signal_schema.py#L325-L326

Added lines #L325 - L326 were not covered by tests
base_model_store_name
)
base_model = ModelStore.get(model_name, version)

Check warning on line 329 in src/datachain/lib/signal_schema.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/signal_schema.py#L329

Added line #L329 was not covered by tests
if base_model:
break

Check warning on line 331 in src/datachain/lib/signal_schema.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/signal_schema.py#L331

Added line #L331 was not covered by tests

return create_feature_model(type_name, fields, base=base_model)
# This can occur if a third-party or custom type is used, which is not available
# when deserializing.
warnings.warn(
Expand Down Expand Up @@ -662,6 +714,9 @@ def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str:
stacklevel=2,
)
return "Any"
if ModelStore.is_pydantic(type_):
ModelStore.register(type_)
return ModelStore.get_name(type_)
return type_.__name__

@staticmethod
Expand Down
Loading

0 comments on commit a7d5da1

Please sign in to comment.