Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(signal schema): serialize base classes for custom types #777

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 89 additions & 27 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
Final,
List,
Literal,
Mapping,
Optional,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel, create_model
from pydantic import BaseModel, Field, create_model
from sqlalchemy import ColumnElement
from typing_extensions import Literal as LiteralEx

Expand Down Expand Up @@ -85,8 +86,31 @@ def __init__(self, method: str, field):
)


class CustomType(BaseModel):
schema_version: int = Field(ge=1, le=2, strict=True)
name: str
fields: dict[str, str]
bases: list[tuple[str, str, Optional[str]]]

@classmethod
def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
version = data.get("schema_version", 1)

if version == 1:
data = {
"schema_version": 1,
"name": type_name,
"fields": data,
"bases": [],
}

return cls(**data)


def create_feature_model(
name: str, fields: dict[str, Union[type, tuple[type, Any]]]
name: str,
fields: Mapping[str, Union[type, None, tuple[type, Any]]],
base: Optional[type] = None,
) -> type[BaseModel]:
"""
This gets or returns a dynamic feature model for use in restoring a model
Expand All @@ -98,7 +122,7 @@ def create_feature_model(
name = name.replace("@", "_")
return create_model(
name,
__base__=DataModel, # type: ignore[call-overload]
__base__=base or DataModel, # type: ignore[call-overload]
# These are tuples for each field of: annotation, default (if any)
**{
field_name: anno if isinstance(anno, tuple) else (anno, None)
Expand Down Expand Up @@ -156,7 +180,7 @@ def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
return SignalSchema(signals)

@staticmethod
def _serialize_custom_model_fields(
def _serialize_custom_model(
version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
) -> str:
"""This serializes any custom type information to the provided custom_types
Expand All @@ -165,12 +189,23 @@ def _serialize_custom_model_fields(
# This type is already stored in custom_types.
return version_name
fields = {}

for field_name, info in fr.model_fields.items():
field_type = info.annotation
# All fields should be typed.
assert field_type
fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
custom_types[version_name] = fields

bases: list[tuple[str, str, Optional[str]]] = []
for type_ in fr.__mro__:
model_store_name = (
ModelStore.get_name(type_) if issubclass(type_, DataModel) else None
)
bases.append((type_.__name__, type_.__module__, model_store_name))

ct = CustomType(schema_version=2, name=version_name, fields=fields, bases=bases)
custom_types[version_name] = ct.model_dump()

return version_name

@staticmethod
Expand All @@ -184,15 +219,12 @@ def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str:
if st is None or not ModelStore.is_pydantic(st):
continue
# Register and save feature types.
ModelStore.register(st)
st_version_name = ModelStore.get_name(st)
if st is fr:
# If the main type is Pydantic, then use the ModelStore version name.
type_name = st_version_name
# Save this type to custom_types.
SignalSchema._serialize_custom_model_fields(
st_version_name, st, custom_types
)
SignalSchema._serialize_custom_model(st_version_name, st, custom_types)
return type_name

def serialize(self) -> dict[str, Any]:
Expand All @@ -215,39 +247,74 @@ def _split_subtypes(type_name: str) -> list[str]:
depth += 1
elif c == "]":
if depth == 0:
raise TypeError(
raise ValueError(
"Extra closing square bracket when parsing subtype list"
)
depth -= 1
elif c == "," and depth == 0:
subtypes.append(type_name[start:i].strip())
start = i + 1
if depth > 0:
raise TypeError("Unclosed square bracket when parsing subtype list")
raise ValueError("Unclosed square bracket when parsing subtype list")
subtypes.append(type_name[start:].strip())
return subtypes

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911
def _deserialize_custom_type(
type_name: str, custom_types: dict[str, Any]
) -> Optional[type]:
"""Given a type name like MyType@v1 gets a type from ModelStore or recreates
it based on the information from the custom types dict that includes fields and
bases."""
model_name, version = ModelStore.parse_name_version(type_name)
fr = ModelStore.get(model_name, version)
if fr:
return fr

if type_name in custom_types:
ct = CustomType.deserialize(custom_types[type_name], type_name)

fields = {
field_name: SignalSchema._resolve_type(field_type_str, custom_types)
for field_name, field_type_str in ct.fields.items()
}

base_model = None
for base in ct.bases:
_, _, model_store_name = base
if model_store_name:
model_name, version = ModelStore.parse_name_version(
model_store_name
)
base_model = ModelStore.get(model_name, version)
if base_model:
break

return create_feature_model(type_name, fields, base=base_model)

return None

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
"""Convert a string-based type back into a python type."""
type_name = type_name.strip()
if not type_name:
raise TypeError("Type cannot be empty")
raise ValueError("Type cannot be empty")
if type_name == "NoneType":
return None

bracket_idx = type_name.find("[")
subtypes: Optional[tuple[Optional[type], ...]] = None
if bracket_idx > -1:
if bracket_idx == 0:
raise TypeError("Type cannot start with '['")
raise ValueError("Type cannot start with '['")
close_bracket_idx = type_name.rfind("]")
if close_bracket_idx == -1:
raise TypeError("Unclosed square bracket when parsing type")
raise ValueError("Unclosed square bracket when parsing type")
if close_bracket_idx < bracket_idx:
raise TypeError("Square brackets are out of order when parsing type")
raise ValueError("Square brackets are out of order when parsing type")
if close_bracket_idx == bracket_idx + 1:
raise TypeError("Empty square brackets when parsing type")
raise ValueError("Empty square brackets when parsing type")
subtype_names = SignalSchema._split_subtypes(
type_name[bracket_idx + 1 : close_bracket_idx]
)
Expand All @@ -267,18 +334,10 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type
return fr[subtypes] # type: ignore[index]
return fr # type: ignore[return-value]

model_name, version = ModelStore.parse_name_version(type_name)
fr = ModelStore.get(model_name, version)
fr = SignalSchema._deserialize_custom_type(type_name, custom_types)
if fr:
return fr

if type_name in custom_types:
fields = custom_types[type_name]
fields = {
field_name: SignalSchema._resolve_type(field_type_str, custom_types)
for field_name, field_type_str in fields.items()
}
return create_feature_model(type_name, fields)
# This can occur if a third-party or custom type is used, which is not available
# when deserializing.
warnings.warn(
Expand Down Expand Up @@ -317,7 +376,7 @@ def deserialize(schema: dict[str, Any]) -> "SignalSchema":
stacklevel=2,
)
continue
except TypeError as err:
except ValueError as err:
raise SignalSchemaError(
f"cannot deserialize '{signal}': {err}"
) from err
Expand Down Expand Up @@ -662,6 +721,9 @@ def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str:
stacklevel=2,
)
return "Any"
if ModelStore.is_pydantic(type_):
ModelStore.register(type_)
return ModelStore.get_name(type_)
return type_.__name__

@staticmethod
Expand Down
Loading
Loading