From 8af236e74911d4b624f27e3b19326d8e894c86bd Mon Sep 17 00:00:00 2001 From: Anton Agestam Date: Sun, 10 Sep 2023 19:43:30 +0200 Subject: [PATCH] feature: Support non-array common struct fields (#78) Building schema with upstream trunk was failing due to lacking support for non-array common struct fields. --- codegen/generate_schema.py | 41 +++++++++++++++++++++++++++++++++----- codegen/parser.py | 40 +++++++++++++++++++++++++++---------- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/codegen/generate_schema.py b/codegen/generate_schema.py index 9b215b3d..1fcce206 100644 --- a/codegen/generate_schema.py +++ b/codegen/generate_schema.py @@ -24,6 +24,8 @@ from .case import to_snake_case from .header_schema import get_header_schema_import from .parser import CommonStructArrayField +from .parser import CommonStructField +from .parser import CommonStructType from .parser import DataSchema from .parser import EntityArrayField from .parser import EntityField @@ -137,7 +139,7 @@ def format_default( def format_dataclass_field( - field_type: Primitive | PrimitiveArrayType | EntityType, + field_type: Primitive | PrimitiveArrayType | EntityType | CommonStructType, default: str | int | float | bool | None, optional: bool, custom_type: CustomTypeDef | None, @@ -166,6 +168,7 @@ def format_dataclass_field( field_kwargs["default"] = "()" elif default is not None: assert not isinstance(field_type, EntityType) + assert not isinstance(field_type, CommonStructType) field_kwargs["default"] = format_default( field_type, default, optional, custom_type ) @@ -310,7 +313,10 @@ def generate_entity_array_field( return f" {to_snake_case(field.name)}: tuple[{field.type}, ...]{field_call}\n" -def generate_entity_field(field: EntityField, version: int) -> str: +def generate_entity_field( + field: EntityField | CommonStructField, + version: int, +) -> str: field_call = format_dataclass_field( field_type=field.type, default=None, @@ -330,11 +336,28 @@ def generate_common_struct_array_field( ) -> str: field_call = format_array_field_call(field, version) return ( - f" {to_snake_case(field.name)}: tuple[{field.type.item_type.name}, ...]" + f" {to_snake_case(field.name)}: tuple[{field.type.struct.name}, ...]" f"{field_call}\n" ) +def generate_common_struct_field( + field: CommonStructField, + version: int, +) -> str: + field_call = format_dataclass_field( + field_type=field.type, + default=None, + optional=( + field.nullableVersions.matches(version) if field.nullableVersions else False + ), + custom_type=None, + tag=field.get_tag(version), + ignorable=field.ignorable, + ) + return f" {to_snake_case(field.name)}: {field.type.struct.name}{field_call}\n" + + seen = set[tuple[str, int]]() @@ -442,11 +465,19 @@ class {name}: elif isinstance(field, CommonStructArrayField): yield from generate_dataclass( schema=schema, - name=field.type.item_type.name, - fields=field.type.item_type.fields, + name=field.type.struct.name, + fields=field.type.struct.fields, version=version, ) class_fields.append(generate_common_struct_array_field(field, version)) + elif isinstance(field, CommonStructField): + yield from generate_dataclass( + schema=schema, + name=field.type.struct.name, + fields=field.type.struct.fields, + version=version, + ) + class_fields.append(generate_common_struct_field(field, version)) else: assert_never(field) diff --git a/codegen/parser.py b/codegen/parser.py index 44f8df68..8e6746c5 100644 --- a/codegen/parser.py +++ b/codegen/parser.py @@ -136,8 +136,18 @@ def parse_primitive_array_type(value: object) -> PrimitiveArrayType: yield parse_primitive_array_type +def parse_common_struct_reference(struct_name: object) -> CommonStruct: + if not isinstance(struct_name, str): + raise ValueError("Common struct reference must be str") + + try: + return structs_registry[struct_name] + except KeyError: + raise ValueError(f"No registered common struct named {struct_name!r}") from None + + class CommonStructArrayType(NamedTuple): - item_type: CommonStruct + struct: CommonStruct @classmethod def __get_validators__(cls) -> Iterator[Callable[[object], CommonStructArrayType]]: @@ -146,21 +156,25 @@ def parse_common_struct_array_type(value: object) -> CommonStructArrayType: raise ValueError("CommonStructArrayType must be str") if not value.startswith("[]"): raise ValueError("CommonStructArrayType must start with '[]'") - struct_name = value.removeprefix("[]") - - try: - struct = structs_registry[struct_name] - except KeyError: - raise ValueError( - f"No registered common struct named {struct_name!r}" - ) from None - + struct = parse_common_struct_reference(struct_name) return CommonStructArrayType(struct) yield parse_common_struct_array_type +class CommonStructType(NamedTuple): + struct: CommonStruct + + @classmethod + def __get_validators__(cls) -> Iterator[Callable[[object], CommonStructType]]: + def parse_common_struct_type(struct_name: object) -> CommonStructType: + struct = parse_common_struct_reference(struct_name) + return CommonStructType(struct) + + yield parse_common_struct_type + + class _BaseField(BaseModel): name: str versions: VersionRange @@ -222,7 +236,7 @@ def get_tag(self, version: int) -> int | None: # Defining this union before its members allows not having to call # EntityField.update_forward_refs(). -Field: TypeAlias = "PrimitiveField | PrimitiveArrayField | EntityArrayField | CommonStructArrayField | EntityField" +Field: TypeAlias = "PrimitiveField | PrimitiveArrayField | EntityArrayField | CommonStructArrayField | EntityField | CommonStructField" timedelta_names: Final = frozenset( { "timeoutMs", @@ -343,6 +357,10 @@ class CommonStructArrayField(_BaseField): type: CommonStructArrayType +class CommonStructField(_BaseField): + type: CommonStructType + + class EntityArrayField(_BaseField): type: EntityArrayType fields: tuple[Field, ...]