Skip to content

Commit

Permalink
feat(signal schema): serialize base classes for custom types
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Jan 4, 2025
1 parent 8dfa4ff commit a8e673f
Show file tree
Hide file tree
Showing 2 changed files with 273 additions and 49 deletions.
108 changes: 89 additions & 19 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def __init__(self, method: str, field):


def create_feature_model(
name: str, fields: dict[str, Union[type, tuple[type, Any]]]
name: str,
fields: dict[str, Union[type, tuple[type, Any]]],
base: Optional[type] = None,
) -> type[BaseModel]:
"""
This gets or returns a dynamic feature model for use in restoring a model
Expand All @@ -98,7 +100,7 @@ def create_feature_model(
name = name.replace("@", "_")
return create_model(
name,
__base__=DataModel, # type: ignore[call-overload]
__base__=base or DataModel, # type: ignore[call-overload]
# These are tuples for each field of: annotation, default (if any)
**{
field_name: anno if isinstance(anno, tuple) else (anno, None)
Expand Down Expand Up @@ -165,12 +167,25 @@ def _serialize_custom_model_fields(
# This type is already stored in custom_types.
return version_name
fields = {}

for field_name, info in fr.model_fields.items():
field_type = info.annotation
# All fields should be typed.
assert field_type
fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
custom_types[version_name] = fields

bases: list[tuple[str, str, Optional[str]]] = []
for type_ in fr.__mro__:
model_store_name = (
ModelStore.get_name(type_) if issubclass(type_, DataModel) else None
)
bases.append((type_.__name__, type_.__module__, model_store_name))

custom_types[version_name] = {
"_custom_types_schema_version": 2,
"fields": fields,
"bases": bases,
}
return version_name

@staticmethod
Expand All @@ -184,7 +199,6 @@ def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str:
if st is None or not ModelStore.is_pydantic(st):
continue
# Register and save feature types.
ModelStore.register(st)
st_version_name = ModelStore.get_name(st)
if st is fr:
# If the main type is Pydantic, then use the ModelStore version name.
Expand Down Expand Up @@ -228,26 +242,87 @@ def _split_subtypes(type_name: str) -> list[str]:
return subtypes

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911
def _get_custom_type_attr(
attr_name: str,
type_name: str,
custom_types: dict[str, Any],
default: Any = None,
) -> Any:
custom_type_description = custom_types.get(type_name)

if custom_type_description is None:
raise SignalSchemaError(
f"cannot deserialize '{type_name}' from custom types"
)

# Backward compatibility with the old custom types schema version that didn't
# include bases, and only had fields.
if "_custom_types_schema_version" not in custom_type_description:
if attr_name == "fields":
return custom_type_description
return default
return custom_type_description[attr_name]

@staticmethod
def _get_or_create_custom_type(
type_name: str, custom_types: dict[str, Any]
) -> Optional[type]:
"""Given a type name like MyType@v1 gets a type from ModelStore or recreates
it based on the information from the custom types dict that includes fields and
bases."""
model_name, version = ModelStore.parse_name_version(type_name)
fr = ModelStore.get(model_name, version)
if fr:
return fr

if type_name in custom_types:
fields = SignalSchema._get_custom_type_attr(
"fields", type_name, custom_types
)
bases = SignalSchema._get_custom_type_attr(
"bases", type_name, custom_types, default=[]
)
fields = {
field_name: SignalSchema._resolve_type(field_type_str, custom_types)
for field_name, field_type_str in fields.items()
}

base_model = None
for base in bases:
_, _, model_store_name = base
if model_store_name:
model_name, version = ModelStore.parse_name_version(
model_store_name
)
base_model = ModelStore.get(model_name, version)
if base_model:
break

return create_feature_model(type_name, fields, base=base_model)

return None

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
"""Convert a string-based type back into a python type."""
type_name = type_name.strip()
if not type_name:
raise TypeError("Type cannot be empty")
raise ValueError("Type cannot be empty")
if type_name == "NoneType":
return None

bracket_idx = type_name.find("[")
subtypes: Optional[tuple[Optional[type], ...]] = None
if bracket_idx > -1:
if bracket_idx == 0:
raise TypeError("Type cannot start with '['")
raise ValueError("Type cannot start with '['")
close_bracket_idx = type_name.rfind("]")
if close_bracket_idx == -1:
raise TypeError("Unclosed square bracket when parsing type")
raise ValueError("Unclosed square bracket when parsing type")
if close_bracket_idx < bracket_idx:
raise TypeError("Square brackets are out of order when parsing type")
raise ValueError("Square brackets are out of order when parsing type")
if close_bracket_idx == bracket_idx + 1:
raise TypeError("Empty square brackets when parsing type")
raise ValueError("Empty square brackets when parsing type")
subtype_names = SignalSchema._split_subtypes(
type_name[bracket_idx + 1 : close_bracket_idx]
)
Expand All @@ -267,18 +342,10 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type
return fr[subtypes] # type: ignore[index]
return fr # type: ignore[return-value]

model_name, version = ModelStore.parse_name_version(type_name)
fr = ModelStore.get(model_name, version)
fr = SignalSchema._get_or_create_custom_type(type_name, custom_types)
if fr:
return fr

if type_name in custom_types:
fields = custom_types[type_name]
fields = {
field_name: SignalSchema._resolve_type(field_type_str, custom_types)
for field_name, field_type_str in fields.items()
}
return create_feature_model(type_name, fields)
# This can occur if a third-party or custom type is used, which is not available
# when deserializing.
warnings.warn(
Expand Down Expand Up @@ -662,6 +729,9 @@ def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str:
stacklevel=2,
)
return "Any"
if ModelStore.is_pydantic(type_):
ModelStore.register(type_)
return ModelStore.get_name(type_)
return type_.__name__

@staticmethod
Expand Down
Loading

0 comments on commit a8e673f

Please sign in to comment.