Skip to content

Commit

Permalink
Merge pull request #149 from rgbkrk/update
Browse files Browse the repository at this point in the history
Update
  • Loading branch information
rgbkrk authored Sep 25, 2024
2 parents 9873fc9 + 0f7c479 commit 659de3e
Show file tree
Hide file tree
Showing 26 changed files with 2,105 additions and 1,749 deletions.
6 changes: 2 additions & 4 deletions chatlab/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,16 +390,14 @@ def register(
self,
function: None = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Callable:
...
) -> Callable: ...

@overload
def register(
self,
function: Callable,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> FunctionDefinition:
...
) -> FunctionDefinition: ...

def register(
self,
Expand Down
8 changes: 6 additions & 2 deletions chatlab/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,19 @@
"""


from typing import Callable, Optional

from pydantic import BaseModel


class ChatlabMetadata(BaseModel):
"""ChatLab metadata for a function."""

expose_exception_to_llm: bool = True
render: Optional[Callable] = None
bubble_exceptions: bool = False


def bubble_exceptions(func):
if not hasattr(func, "chatlab_metadata"):
func.chatlab_metadata = ChatlabMetadata()
Expand All @@ -51,6 +53,7 @@ def bubble_exceptions(func):
func.chatlab_metadata.bubble_exceptions = True
return func


def expose_exception_to_llm(func):
"""Expose exceptions from calling the function to the LLM.
Expand Down Expand Up @@ -107,6 +110,7 @@ def store_knowledge_graph(kg: KnowledgeGraph, comment: str = "Knowledge Graph"):
chat.register(store_knowledge_graph)
'''


def incremental_display(render_func: Callable):
def decorator(func):
if not hasattr(func, "chatlab_metadata"):
Expand All @@ -118,5 +122,5 @@ def decorator(func):

func.chatlab_metadata.render = render_func
return func
return decorator

return decorator
3 changes: 1 addition & 2 deletions chatlab/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam:


class HasGetToolArgumentsParameter(Protocol):
def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam:
...
def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam: ...


def assistant_tool_calls(tool_calls: Iterable[HasGetToolArgumentsParameter]) -> ChatCompletionMessageParam:
Expand Down
1 change: 1 addition & 0 deletions chatlab/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class ChatModel(Enum):
"""Models available for use with chatlab."""

GPT_4_TURBO_PREVIEW = "gpt-4-turbo-preview"
GPT_4_0125_PREVIEW = "gpt-4-0125-preview"
GPT_4_1106_PREVIEW = "gpt-4-1106-preview"
Expand Down
14 changes: 9 additions & 5 deletions chatlab/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,14 @@ def register(
self,
function: None = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Callable:
...
) -> Callable: ...

@overload
def register(
self,
function: Callable,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> FunctionDefinition:
...
) -> FunctionDefinition: ...

def register(
self,
Expand Down Expand Up @@ -438,7 +436,13 @@ def api_manifest(self, function_call_option: FunctionCall = "auto") -> APIManife

@property
def tools(self) -> Iterable[ChatCompletionToolParam]:
return [{"type": "function", "function": adapt_function_definition(f)} for f in self.__schemas.values()]
return [
ChatCompletionToolParam(
type="function",
function=adapt_function_definition(f), # type: ignore
)
for f in self.__schemas.values()
]

async def call(self, name: str, arguments: Optional[str] = None) -> Any:
"""Call a function by name with the given parameters."""
Expand Down
1 change: 0 additions & 1 deletion chatlab/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@
"run_python",
"shell_functions",
]

1 change: 1 addition & 0 deletions chatlab/tools/_mediatypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Media types for rich output for LLMs and in-notebook."""

import json
from typing import Optional

Expand Down
1 change: 1 addition & 0 deletions chatlab/tools/colors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Let models pick and show color palettes to you."""

import hashlib
from typing import List, Optional
from pydantic import BaseModel, validator, Field
Expand Down
1 change: 1 addition & 0 deletions chatlab/tools/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
You've been warned. Have fun and be safe!
"""

import asyncio
import os

Expand Down
1 change: 1 addition & 0 deletions chatlab/tools/python.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The in-IPython python code runner for ChatLab."""

from traceback import TracebackException
from typing import Optional

Expand Down
1 change: 1 addition & 0 deletions chatlab/tools/shell.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Shell commands for ChatLab."""

import asyncio
import subprocess

Expand Down
7 changes: 2 additions & 5 deletions chatlab/views/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""Views for ChatLab."""

from .assistant import AssistantMessageView
from .tools import ToolArguments, ToolCalled

__all__ = [
"AssistantMessageView",
"ToolArguments",
"ToolCalled"
]
__all__ = ["AssistantMessageView", "ToolArguments", "ToolCalled"]
27 changes: 16 additions & 11 deletions chatlab/views/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from IPython.display import display
from IPython.core.getipython import get_ipython

from instructor.dsl.partialjson import JSONParser

from jiter import from_json


class ToolArguments(AutoUpdate):
Expand Down Expand Up @@ -75,11 +74,11 @@ def update(self) -> None:

def render(self):
if self.custom_render is not None:
# We use the same definition as was in the original function
try:
parser = JSONParser()
possible_args = parser.parse(self.arguments)

possible_args = from_json(self.arguments.encode("utf-8"), partial_mode="trailing-strings")
except Exception:
return None
try:
Model = extract_model_from_function(self.name, self.custom_render)
# model = Model.model_validate(possible_args)
model = Model(**possible_args)
Expand Down Expand Up @@ -110,13 +109,17 @@ def append_arguments(self, arguments: str):
def apply_result(self, result: str):
"""Replaces the existing display with a new one that shows the result of the tool being called."""
tc = ToolCalled(
id=self.id, name=self.name, arguments=self.arguments, result=result, display_id=self.display_id,
custom_render=self.custom_render
id=self.id,
name=self.name,
arguments=self.arguments,
result=result,
display_id=self.display_id,
custom_render=self.custom_render,
)
tc.update()
return tc

async def call(self, function_registry: FunctionRegistry) -> 'ToolCalled':
async def call(self, function_registry: FunctionRegistry) -> "ToolCalled":
"""Call the function and return a stack of messages for LLM and human consumption."""
function_name = self.name
function_args = self.arguments
Expand Down Expand Up @@ -185,9 +188,11 @@ def render(self):
if self.custom_render is not None:
# We use the same definition as was in the original function
try:
parser = JSONParser()
possible_args = parser.parse(self.arguments)
possible_args = from_json(self.arguments.encode("utf-8"), partial_mode="trailing-strings")
except Exception:
return None

try:
Model = extract_model_from_function(self.name, self.custom_render)
# model = Model.model_validate(possible_args)
model = Model(**possible_args)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
"outputs": [],
"source": [
"from datetime import datetime\n",
"from pytz import timezone, all_timezones, utc\n",
"from pytz import timezone, all_timezones\n",
"from typing import Optional\n",
"from pydantic import BaseModel\n",
"\n",
Expand Down
Loading

0 comments on commit 659de3e

Please sign in to comment.