Skip to content

Commit

Permalink
[python-fastapi] support oneOf the pydantic v2 way
Browse files Browse the repository at this point in the history
* Support oneOf and anyOf schemas the pydantic v2 way by generating them as Unions.
* Generate model constructor that forcefully sets the discriminator field to ensure it is included in the marshalled representation.
  • Loading branch information
mgoltzsche committed Oct 19, 2024
1 parent 7bd1bc4 commit 1605ab8
Show file tree
Hide file tree
Showing 21 changed files with 218 additions and 341 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,8 @@ public Map<String, ModelsMap> postProcessAllModels(Map<String, ModelsMap> objs)
codegenModelMap.put(cm.classname, ModelUtils.getModelByName(entry.getKey(), objs));
}

propagateDiscriminatorValuesToProperties(processed);

// create circular import
for (String m : codegenModelMap.keySet()) {
createImportMapOfSet(m, codegenModelMap);
Expand Down Expand Up @@ -1046,6 +1048,52 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {
return objs;
}

private void propagateDiscriminatorValuesToProperties(Map<String, ModelsMap> objMap) {
HashMap<String, CodegenModel> modelMap = new HashMap<>();
for (Map.Entry<String, ModelsMap> entry : objMap.entrySet()) {
for (ModelMap m : entry.getValue().getModels()) {
modelMap.put("#/components/schemas/" + entry.getKey(), m.getModel());
}
}

for (Map.Entry<String, ModelsMap> entry : objMap.entrySet()) {
for (ModelMap m : entry.getValue().getModels()) {
CodegenModel model = m.getModel();
if (model.discriminator != null && !model.oneOf.isEmpty()) {
// Populate default, implicit discriminator values
for (String typeName : model.oneOf) {
ModelsMap obj = objMap.get(typeName);
if (obj == null) {
continue;
}
for (ModelMap m1 : obj.getModels()) {
for (CodegenProperty p : m1.getModel().vars) {
if (p.baseName.equals(model.discriminator.getPropertyBaseName())) {
p.isDiscriminator = true;
p.discriminatorValue = typeName;
}
}
}
}
// Populate explicit discriminator values from mapping, overwriting default values
if (model.discriminator.getMapping() != null) {
for (Map.Entry<String, String> discrEntry : model.discriminator.getMapping().entrySet()) {
CodegenModel resolved = modelMap.get(discrEntry.getValue());
if (resolved != null) {
for (CodegenProperty p : resolved.vars) {
if (p.baseName.equals(model.discriminator.getPropertyBaseName())) {
p.isDiscriminator = true;
p.discriminatorValue = discrEntry.getKey();
}
}
}
}
}
}
}
}
}


/*
* Gets the pydantic type given a Codegen Property
Expand Down Expand Up @@ -2160,7 +2208,16 @@ private PythonType getType(CodegenProperty cp) {
}

private String finalizeType(CodegenProperty cp, PythonType pt) {
if (!cp.required || cp.isNullable) {
if (cp.isDiscriminator && cp.discriminatorValue != null) {
moduleImports.add("typing", "Literal");
PythonType literal = new PythonType("Literal");
String literalValue = '"'+escapeText(cp.discriminatorValue)+'"';
PythonType valueType = new PythonType(literalValue);
literal.addTypeParam(valueType);
literal.setDefaultValue(literalValue);
cp.setDefaultValue(literalValue);
pt = literal;
} else if (!cp.required || cp.isNullable) {
moduleImports.add("typing", "Optional");
PythonType opt = new PythonType("Optional");
opt.addTypeParam(pt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,174 +14,56 @@ import re # noqa: F401
{{/vendorExtensions.x-py-model-imports}}
from typing import Union, Any, List, TYPE_CHECKING, Optional, Dict
from typing_extensions import Literal
from pydantic import StrictStr, Field
from pydantic import StrictStr, Field, RootModel
try:
from typing import Self
except ImportError:
from typing_extensions import Self

{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS = [{{#anyOf}}"{{.}}"{{^-last}}, {{/-last}}{{/anyOf}}]

class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}):
class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}RootModel{{/parent}}):
"""
{{{description}}}{{^description}}{{{classname}}}{{/description}}
"""

{{#composedSchemas.anyOf}}
# data type: {{{dataType}}}
{{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}}
{{/composedSchemas.anyOf}}
if TYPE_CHECKING:
actual_instance: Optional[Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}]] = None
else:
actual_instance: Any = None
any_of_schemas: List[str] = Literal[{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS]
root: Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}] = None

model_config = {
"validate_assignment": True,
"protected_namespaces": (),
}
{{#discriminator}}

discriminator_value_class_map: Dict[str, str] = {
{{#children}}
'{{^vendorExtensions.x-discriminator-value}}{{name}}{{/vendorExtensions.x-discriminator-value}}{{#vendorExtensions.x-discriminator-value}}{{{vendorExtensions.x-discriminator-value}}}{{/vendorExtensions.x-discriminator-value}}': '{{{classname}}}'{{^-last}},{{/-last}}
{{/children}}
}
{{/discriminator}}

def __init__(self, *args, **kwargs) -> None:
if args:
if len(args) > 1:
raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`")
if kwargs:
raise ValueError("If a position argument is used, keyword arguments cannot be used.")
super().__init__(actual_instance=args[0])
else:
super().__init__(**kwargs)

@field_validator('actual_instance')
def actual_instance_must_validate_anyof(cls, v):
{{#isNullable}}
if v is None:
return v

{{/isNullable}}
instance = {{{classname}}}.model_construct()
error_messages = []
{{#composedSchemas.anyOf}}
# validate data type: {{{dataType}}}
{{#isContainer}}
try:
instance.{{vendorExtensions.x-py-name}} = v
return v
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isContainer}}
{{^isContainer}}
{{#isPrimitiveType}}
try:
instance.{{vendorExtensions.x-py-name}} = v
return v
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isPrimitiveType}}
{{^isPrimitiveType}}
if not isinstance(v, {{{dataType}}}):
error_messages.append(f"Error! Input type `{type(v)}` is not `{{{dataType}}}`")
else:
return v

{{/isPrimitiveType}}
{{/isContainer}}
{{/composedSchemas.anyOf}}
if error_messages:
# no match
raise ValueError("No match found when setting the actual_instance in {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages))
else:
return v

@classmethod
def from_dict(cls, obj: dict) -> Self:
return cls.from_json(json.dumps(obj))

@classmethod
def from_json(cls, json_str: str) -> Self:
"""Returns the object represented by the json string"""
instance = cls.model_construct()
{{#isNullable}}
if json_str is None:
return instance

{{/isNullable}}
error_messages = []
{{#composedSchemas.anyOf}}
{{#isContainer}}
# deserialize data into {{{dataType}}}
try:
# validation
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
# assign value to actual_instance
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
return instance
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isContainer}}
{{^isContainer}}
{{#isPrimitiveType}}
# deserialize data into {{{dataType}}}
try:
# validation
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
# assign value to actual_instance
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
return instance
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isPrimitiveType}}
{{^isPrimitiveType}}
# {{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}}
try:
instance.actual_instance = {{{dataType}}}.from_json(json_str)
return instance
except (ValidationError, ValueError) as e:
error_messages.append(str(e))
{{/isPrimitiveType}}
{{/isContainer}}
{{/composedSchemas.anyOf}}

if error_messages:
# no match
raise ValueError("No match found when deserializing the JSON string into {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages))
else:
return instance
def to_str(self) -> str:
"""Returns the string representation of the model using alias"""
return pprint.pformat(self.model_dump(by_alias=True))

def to_json(self) -> str:
"""Returns the JSON representation of the actual instance"""
if self.actual_instance is None:
return "null"
"""Returns the JSON representation of the model using alias"""
return self.model_dump_json(by_alias=True, exclude_unset=True)

to_json = getattr(self.actual_instance, "to_json", None)
if callable(to_json):
return self.actual_instance.to_json()
@classmethod
def from_json(cls, json_str: str) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}:
"""Create an instance of {{{classname}}} from a JSON string"""
return cls.from_dict(json.loads(json_str))

def to_dict(self) -> Dict[str, Any]:
"""Return the dictionary representation of the model using alias"""
to_dict = getattr(self.root, "to_dict", None)
if callable(to_dict):
return self.model_dump(by_alias=True, exclude_unset=True)
else:
return json.dumps(self.actual_instance)
# primitive type
return self.root

def to_dict(self) -> Dict:
"""Returns the dict representation of the actual instance"""
if self.actual_instance is None:
return "null"
@classmethod
def from_dict(cls, obj: Dict) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}:
"""Create an instance of {{{classname}}} from a dict"""
if obj is None:
return None

to_json = getattr(self.actual_instance, "to_json", None)
if callable(to_json):
return self.actual_instance.to_dict()
else:
# primitive type
return self.actual_instance
if not isinstance(obj, dict):
return cls.model_validate(obj)

def to_str(self) -> str:
"""Returns the string representation of the actual instance"""
return pprint.pformat(self.model_dump())
return cls.parse_obj(obj)

{{#vendorExtensions.x-py-postponed-model-imports.size}}
{{#vendorExtensions.x-py-postponed-model-imports}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}
{{/isAdditionalPropertiesTrue}}
}

def __init__(self, *a, **kw):
super().__init__(*a, **kw)
{{#vars}}
{{#isDiscriminator}}
self.{{name}} = self.{{name}}
{{/isDiscriminator}}
{{/vars}}

def to_str(self) -> str:
"""Returns the string representation of the model using alias"""
Expand Down
Loading

0 comments on commit 1605ab8

Please sign in to comment.