Skip to content

Commit

Permalink
feat: refactor tool calling logic
Browse files Browse the repository at this point in the history
  • Loading branch information
CNSeniorious000 committed Dec 30, 2024
1 parent 2f2f16e commit ebe8b45
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 27 deletions.
25 changes: 16 additions & 9 deletions src/logic/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from asyncio import ensure_future, gather
from asyncio import Task, ensure_future, gather
from functools import cached_property
from json import dumps
from typing import cast

Expand All @@ -20,6 +21,18 @@ class TypedContext(ChainContext):
partial = True
parsed: Output = {}

@cached_property
def _tasks(self) -> list[Task]:
return []

def call_tools(self):
actions = self.parsed.get("actions", [])
for action in actions[len(self._tasks) : len(actions) - self.partial]:
task = ensure_future(call_tool(action["name"], body := action.get("body", {})))
self._tasks.append(task)
print(f"start <{action['name']}> with {body}")
return self._tasks


main = Node(load_template("main"), TypedContext({"tools": tools}))

Expand All @@ -45,7 +58,7 @@ async def collect_results(context: TypedContext):
if not actions:
return

results = await gather(*(call_tool(i["name"], i.get("body", {})) for i in actions))
results = await gather(*context.call_tools())

messages = cast(list[Message], context["messages"])

Expand Down Expand Up @@ -83,10 +96,4 @@ def parse_json(context: TypedContext):
finally:
context["parsed"] = context.parsed


@main.mid_process
async def run_tools(context: TypedContext):
if actions := context.parsed.get("actions"):
for action in actions[:-1] if context.partial else actions:
ensure_future(call_tool(action["name"], body := action.get("body", {})))
print(f"start <{action['name']}> with {body}")
context.call_tools()
20 changes: 2 additions & 18 deletions src/logic/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,6 @@
tool_map: dict[str, AbstractTool] = {tool.name: tool for tool in tools}


states: dict[tuple[str, str], Future | JSON] = {}


def hashable(json: dict):
return dumps(json, sort_keys=True)


async def call_tool(name: str, body: Context):
key = (name, hashable(body))

if key not in states:
job = states[key] = Future()
tool = tool_map[name]
result = await resolve(tool(**body))
job.set_result(result)
return result

job = states.get(key)
return cast(JSON, await job if isinstance(job, Future) else job)
tool = tool_map[name]
return await resolve(tool(**body))

0 comments on commit ebe8b45

Please sign in to comment.