Skip to content

Commit

Permalink
Merge pull request #95 from dapper91/pydantic-2
Browse files Browse the repository at this point in the history
- pydantic 2 support added.
  • Loading branch information
dapper91 authored Sep 26, 2023
2 parents a79c33a + 7c9a1af commit 440f686
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 22 deletions.
12 changes: 6 additions & 6 deletions pjrpc/server/specs/extractors/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def extract_params_schema(self, method: MethodType, exclude: Iterable[str] = ())
)

params_model = pd.create_model('RequestModel', **field_definitions)
model_schema = params_model.schema(ref_template=self._ref_template)
model_schema = params_model.model_json_schema(ref_template=self._ref_template)

parameters_schema = {}
for param_name, param_schema in model_schema['properties'].items():
Expand All @@ -45,7 +45,7 @@ def extract_params_schema(self, method: MethodType, exclude: Iterable[str] = ())
description=param_schema.get('description', UNSET),
deprecated=param_schema.get('deprecated', UNSET),
required=required,
definitions=model_schema.get('definitions'),
definitions=model_schema.get('$defs'),
)

return parameters_schema
Expand All @@ -60,8 +60,8 @@ def extract_result_schema(self, method: MethodType) -> Schema:
else:
return_annotation = result.return_annotation

result_model = pd.create_model('ResultModel', result=(return_annotation, pd.fields.Undefined))
model_schema = result_model.schema(ref_template=self._ref_template)
result_model = pd.create_model('ResultModel', result=(return_annotation, ...))
model_schema = result_model.model_json_schema(ref_template=self._ref_template)

result_schema = model_schema['properties']['result']
required = 'result' in model_schema.get('required', [])
Expand Down Expand Up @@ -95,7 +95,7 @@ def extract_errors_schema(
field_definitions[field_name] = (annotation, getattr(error, field_name, ...))

result_model = pd.create_model(error.message, **field_definitions)
model_schema = result_model.schema(ref_template=self._ref_template)
model_schema = result_model.model_json_schema(ref_template=self._ref_template)

data_schema = model_schema['properties'].get('data', UNSET)
required = 'data' in model_schema.get('required', [])
Expand All @@ -109,7 +109,7 @@ def extract_errors_schema(
title=error.message,
description=inspect.cleandoc(error.__doc__) if error.__doc__ is not None else UNSET,
deprecated=model_schema.get('deprecated', UNSET),
definitions=model_schema.get('definitions'),
definitions=model_schema.get('$defs'),
),
)
return errors_schema
Expand Down
6 changes: 3 additions & 3 deletions pjrpc/server/validators/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, coerce: bool = True, **config_args: Any):
config_args.setdefault('extra', 'forbid')

# https://pydantic-docs.helpmanual.io/usage/model_config/
self._model_config = type('ModelConfig', (pydantic.BaseConfig,), config_args)
self._model_config = pydantic.ConfigDict(**config_args)

def validate_method(
self, method: Callable[..., Any], params: Optional['JsonRpcParams'], exclude: Iterable[str] = (), **kwargs: Any,
Expand All @@ -43,15 +43,15 @@ def validate_method(
signature = self.signature(method, tuple(exclude))
schema = self.build_validation_schema(signature)

params_model = pydantic.create_model(method.__name__, **schema, __config__=self._model_config)
params_model = pydantic.create_model(method.__name__, **schema, model_config=self._model_config)

bound_params = self.bind(signature, params)
try:
obj = params_model(**bound_params.arguments)
except pydantic.ValidationError as e:
raise base.ValidationError(*e.errors()) from e

return {attr: getattr(obj, attr) for attr in obj.__fields_set__} if self._coerce else bound_params.arguments
return {attr: getattr(obj, attr) for attr in obj.model_fields} if self._coerce else bound_params.arguments

@ft.lru_cache(maxsize=None)
def build_validation_schema(self, signature: inspect.Signature) -> Dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jsonschema = {version = ">=3.0,<4.0", optional = true}
kombu = { version = ">=5.1", optional = true }
markupsafe = { version = "==2.0.1", optional = true }
openapi-ui-bundles = { version = ">=0.1", optional = true }
pydantic = {version = ">=1.7.0,<2.0", optional = true}
pydantic = {version = ">=2.0", optional = true}
requests = { version = ">=2.0", optional = true }
starlette = { version = ">=0.25.0", optional = true }
werkzeug = { version = ">=2.0", optional = true}
Expand Down
43 changes: 31 additions & 12 deletions tests/server/resources/openapi-1.json
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,15 @@
]
},
"result": {
"title": "Result",
"type": "string",
"nullable": "true"
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Result"
}
},
"required": [
Expand Down Expand Up @@ -532,12 +538,20 @@
"type": "object",
"properties": {
"param1": {
"title": "Param1",
"type": "number"
"anyOf": [
{
"type": "number"
},
{
"type": "null"
}
],
"default": null,
"title": "Param1"
},
"param2": {
"title": "Param2",
"default": 1,
"title": "Param2",
"type": "integer"
}
}
Expand Down Expand Up @@ -876,8 +890,7 @@
]
},
"result": {
"title": "Result",
"nullable": "true"
"title": "Result"
}
},
"required": [
Expand Down Expand Up @@ -1017,8 +1030,7 @@
]
},
"result": {
"title": "Result",
"nullable": "true"
"title": "Result"
}
},
"required": [
Expand Down Expand Up @@ -1190,9 +1202,16 @@
"type": "string"
},
"field2": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"title": "Field2",
"default": 1,
"type": "integer"
"default": 1
},
"field3": {
"$ref": "#/components/schemas/SubModel"
Expand Down

0 comments on commit 440f686

Please sign in to comment.