Skip to content

Commit

Permalink
Merge pull request #93 from shouples/djs/func-schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk authored Sep 30, 2023
2 parents ac0b34d + 174d9ab commit e3c3992
Show file tree
Hide file tree
Showing 2 changed files with 293 additions and 86 deletions.
162 changes: 86 additions & 76 deletions chatlab/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,22 @@ class WhatTime(BaseModel):
import inspect
import json
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Type, Union, get_args, get_origin, overload

from pydantic import BaseModel
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Type,
Union,
get_args,
get_origin,
overload,
)

from pydantic import BaseModel, create_model

from .decorators import ChatlabMetadata

Expand Down Expand Up @@ -87,48 +100,10 @@ def is_union_type(t):
return get_origin(t) is Union


def process_type(annotation, is_required=True):
"""Determine the JSON schema type of a type annotation."""
origin = get_origin(annotation)
args = get_args(annotation)

if is_optional_type(annotation):
return process_type(args[0], is_required=False)

elif origin is Union:
types = [process_type(t, is_required)[0]["type"] for t in args if t is not type(None)] # noqa: E721
return {"type": types}, is_required

elif origin is list:
item_type = process_type(args[0], is_required)[0]["type"]
return {"type": "array", "items": {"type": item_type}}, is_required

elif origin is Literal:
values = get_args(annotation)
return {"type": "string", "enum": values}, is_required
class FunctionSchemaConfig:
"""Config used for model generation during function schema creation."""

elif issubclass(annotation, Enum):
values = [e.name for e in annotation]
return {"type": "string", "enum": values}, is_required

elif origin is dict:
return {"type": "object"}, is_required

elif annotation in ALLOWED_TYPES:
return {
"type": JSON_SCHEMA_TYPES[annotation],
}, is_required

else:
raise Exception(f"Type annotation must be a JSON serializable type ({ALLOWED_TYPES})")


def process_parameter(name, param):
"""Process a function parameter for use in a JSON schema."""
prop_schema, is_required = process_type(param.annotation, param.default == inspect.Parameter.empty)
if param.default != inspect.Parameter.empty:
prop_schema["default"] = param.default
return prop_schema, is_required
arbitrary_types_allowed = True


def generate_function_schema(
Expand All @@ -146,38 +121,63 @@ def generate_function_schema(
if not doc:
raise Exception("Only functions with docstrings can be registered")

schema = None
schema = {
"name": func_name,
"description": doc,
"parameters": {},
}

if isinstance(parameter_schema, dict):
schema = parameter_schema
parameters = parameter_schema
elif parameter_schema is not None:
schema = parameter_schema.schema()
parameters = parameter_schema.schema()
else:
schema_properties = {}
required = []

# extract function parameters and their type annotations
sig = inspect.signature(function)

fields = {}
for name, param in sig.parameters.items():
prop_schema, is_required = process_parameter(name, param)
schema_properties[name] = prop_schema
if is_required:
required.append(name)

schema = {"type": "object", "properties": {}, "required": []}
if len(schema_properties) > 0:
schema = {
"type": "object",
"properties": schema_properties,
"required": required,
}

if schema is None:
raise Exception(f"Could not generate schema for function {func_name}")

return {
"name": func_name,
"description": doc,
"parameters": schema,
}
# skip 'self' for class methods
if name == "self":
continue

# determine type annotation
if param.annotation == inspect.Parameter.empty:
# no annotation, raise instead of falling back to Any
raise Exception(
f"`{name}` parameter of {func_name} must have a JSON-serializable type annotation"
)
type_annotation = param.annotation

# get the default value, otherwise set as required
default_value = ...
if param.default != inspect.Parameter.empty:
default_value = param.default

fields[name] = (type_annotation, default_value)

# create the pydantic model and return its JSON schema to pass into the 'parameters' part of the
# function schema used by OpenAI
model = create_model(
function.__name__,
__config__=FunctionSchemaConfig,
**fields,
)
parameters: dict = model.schema()

if "properties" not in parameters:
parameters["properties"] = {}

# remove "title" since it's unused by OpenAI
parameters.pop("title", None)
for field_name in parameters["properties"].keys():
parameters["properties"][field_name].pop("title", None)

if "required" not in parameters:
parameters["required"] = []

schema["parameters"] = parameters
return schema


# Declare the type for the python hallucination
Expand Down Expand Up @@ -232,7 +232,9 @@ def __init__(self, python_hallucination_function: Optional[PythonHallucinationFu

self.python_hallucination_function = python_hallucination_function

def decorator(self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Callable:
def decorator(
self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
) -> Callable:
"""Create a decorator for registering functions with a schema."""

def decorator(function):
Expand All @@ -243,16 +245,22 @@ def decorator(function):

@overload
def register(
self, function: None = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
self,
function: None = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Callable:
...

@overload
def register(self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Dict:
def register(
self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
) -> Dict:
...

def register(
self, function: Optional[Callable] = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
self,
function: Optional[Callable] = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Union[Callable, Dict]:
"""Register a function for use in `Chat`s. Can be used as a decorator or directly to register a function.
Expand Down Expand Up @@ -407,7 +415,9 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
parameters = json.loads(arguments)
# TODO: Validate parameters against schema
except json.JSONDecodeError:
raise FunctionArgumentError(f"Invalid Function call on {name}. Arguments must be a valid JSON object")
raise FunctionArgumentError(
f"Invalid Function call on {name}. Arguments must be a valid JSON object"
)

if function is None:
raise UnknownFunctionError(f"Function {name} is not registered")
Expand Down
Loading

0 comments on commit e3c3992

Please sign in to comment.