Skip to content

Commit

Permalink
fix: patch annotations inside a container type (#5167)
Browse files Browse the repository at this point in the history
* fix: patch annotations inside a container

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored Jan 6, 2025
1 parent 3fe9653 commit 425239d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
36 changes: 36 additions & 0 deletions src/_bentoml_sdk/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing as t

from pydantic._internal import _known_annotated_metadata
from pydantic._internal._typing_extra import is_annotated

from .typing_utils import get_args
from .typing_utils import get_origin
Expand Down Expand Up @@ -194,3 +195,38 @@ def pathlib_prepare_pydantic_annotations(
# PIL image
pil_prepare_pydantic_annotations,
]

SUPPORTED_CONTAINER_TYPES = [
t.Union,
list,
t.List,
dict,
t.Dict,
t.AsyncGenerator,
t.AsyncIterable,
t.AsyncIterator,
t.Generator,
t.Iterable,
t.Iterator,
]


def patch_annotation(annotation: t.Any, model_config: ConfigDict) -> t.Any:
import typing_extensions as te

origin, args = te.get_origin(annotation), te.get_args(annotation)
if origin in SUPPORTED_CONTAINER_TYPES:
patched_args = [patch_annotation(arg, model_config) for arg in args]
return origin[tuple(patched_args)]

if is_annotated(annotation):
source, *annotations = args
else:
source = annotation
annotations = []
for method in CUSTOM_PREPARE_METHODS:
result = method(source, annotations, model_config)
if result is None:
continue
return t.Annotated[(result[0], *result[1])] # type: ignore
return annotation
14 changes: 2 additions & 12 deletions src/_bentoml_sdk/io_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,10 @@ def mime_type(cls) -> str:

@classmethod
def __get_pydantic_core_schema__(cls: type[BaseModel], source, handler):
from ._pydantic import CUSTOM_PREPARE_METHODS
from ._pydantic import patch_annotation

for _, info in cls.model_fields.items():
if is_annotated(info.annotation):
origin, *args = get_args(info.annotation)
else:
origin = info.annotation
args = []
for method in CUSTOM_PREPARE_METHODS:
result = method(origin, args, cls.model_config)
if result is None:
continue
info.annotation = t.Annotated[(result[0], *result[1])] # type: ignore
break
info.annotation = patch_annotation(info.annotation, cls.model_config)

return super().__get_pydantic_core_schema__(source, handler)

Expand Down

0 comments on commit 425239d

Please sign in to comment.