Skip to content

Commit

Permalink
Use code execution status rather than chat message status.
Browse files Browse the repository at this point in the history
  • Loading branch information
EtiennePerot committed Oct 13, 2024
1 parent 0d13de2 commit c0937d0
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 141 deletions.
131 changes: 94 additions & 37 deletions open-webui/functions/run_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ async def action(
valves = self.valves
debug = valves.DEBUG
emitter = EventEmitter(__event_emitter__, debug=debug)
execution_tracker: typing.Optional[CodeExecutionTracker] = None

update_check_error = None
update_check_notice = ""
Expand Down Expand Up @@ -194,6 +195,9 @@ async def action(
)

async def _fail(error_message, status="SANDBOX_ERROR"):
if execution_tracker is not None:
execution_tracker.set_error(error_message)
await emitter.code_execution(execution_tracker)
if debug:
await emitter.fail(
f"[DEBUG MODE] {error_message}; body={body}; valves=[{valves}]"
Expand All @@ -204,7 +208,6 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
await emitter.fail(error_message)
return json.dumps({"status": status, "output": error_message})

await emitter.status("Checking messages for code blocks...")
if len(body.get("messages", ())) == 0:
return await _fail("No messages in conversation.", status="INVALID_INPUT")
last_message = body["messages"][-1]
Expand Down Expand Up @@ -271,7 +274,6 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
if self.valves.MAX_RAM_MEGABYTES != 0:
max_ram_bytes = self.valves.MAX_RAM_MEGABYTES * 1024 * 1024

await emitter.status("Checking if environment supports sandboxing...")
Sandbox.check_setup(
language=language,
auto_install_allowed=self.valves.AUTO_INSTALL,
Expand All @@ -282,7 +284,6 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
await emitter.status("Auto-installing gVisor...")
Sandbox.install_runsc()

await emitter.status("Initializing sandbox configuration...")
status = "UNKNOWN"
output = None
generated_files = []
Expand All @@ -298,6 +299,12 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
code = code.removeprefix("bash")
code = code.removeprefix("sh")
code = code.strip()
language_title = language.title()
execution_tracker = CodeExecutionTracker(
name=f"{language_title} code block", code=code, language=language
)
await emitter.clear_status()
await emitter.code_execution(execution_tracker)

with tempfile.TemporaryDirectory(prefix="sandbox_") as tmp_dir:
sandbox_storage_path = os.path.join(tmp_dir, "storage")
Expand All @@ -314,23 +321,25 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
persistent_home_dir=sandbox_storage_path,
)

await emitter.status(
f"Running {language_title} code in gVisor sandbox..."
)
try:
result = sandbox.run()
except Sandbox.ExecutionTimeoutError as e:
await emitter.fail(
f"Code timed out after {valves.MAX_RUNTIME_SECONDS} seconds"
)
execution_tracker.set_error(
f"Code timed out after {valves.MAX_RUNTIME_SECONDS} seconds"
)
status = "TIMEOUT"
output = e.stderr
except Sandbox.InterruptedExecutionError as e:
await emitter.fail("Code used too many resources")
execution_tracker.set_error("Code used too many resources")
status = "INTERRUPTED"
output = e.stderr
except Sandbox.CodeExecutionError as e:
await emitter.fail(f"{language_title}: {e}")
execution_tracker.set_error(f"{language_title}: {e}")
status = "ERROR"
output = e.stderr
else:
Expand All @@ -354,14 +363,14 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
status = "STORAGE_ERROR"
output = f"Storage quota exceeded: {e}"
await emitter.fail(output)
if status == "OK":
await emitter.status(
status="complete",
done=True,
description=f"{language_title} code executed successfully.",
)
for generated_file in generated_files:
execution_tracker.add_file(
name=generated_file.name, url=generated_file.url
)
if output:
output = output.strip()
execution_tracker.set_output(output)
await emitter.code_execution(execution_tracker)
if debug:
per_file_logs = {}

Expand Down Expand Up @@ -935,6 +944,9 @@ async def action(
)


# fmt: off


class EventEmitter:
"""
Helper wrapper for OpenWebUI event emissions.
Expand All @@ -948,27 +960,32 @@ def __init__(
self.event_emitter = event_emitter
self._debug = debug
self._status_prefix = None
self._emitted_status = False

def set_status_prefix(self, status_prefix):
self._status_prefix = status_prefix

async def _emit(self, typ, data):
async def _emit(self, typ, data, twice):
if self._debug:
print(f"Emitting {typ} event: {data}", file=sys.stderr)
if not self.event_emitter:
return None
maybe_future = self.event_emitter(
{
"type": typ,
"data": data,
}
)
if asyncio.isfuture(maybe_future) or inspect.isawaitable(maybe_future):
return await maybe_future
result = None
for i in range(2 if twice else 1):
maybe_future = self.event_emitter(
{
"type": typ,
"data": data,
}
)
if asyncio.isfuture(maybe_future) or inspect.isawaitable(maybe_future):
result = await maybe_future
return result

async def status(
self, description="Unknown state", status="in_progress", done=False
):
self._emitted_status = True
if self._status_prefix is not None:
description = f"{self._status_prefix}{description}"
await self._emit(
Expand All @@ -978,29 +995,33 @@ async def status(
"description": description,
"done": done,
},
twice=not done and len(description) <= 1024,
)
if not done and len(description) <= 1024:
# Emit it again; Open WebUI does not seem to flush this reliably.
# Only do it for relatively small statuses; when debug mode is enabled,
# this can take up a lot of space.
await self._emit(
"status",
{
"status": status,
"description": description,
"done": done,
},
)

async def fail(self, description="Unknown error"):
await self.status(description=description, status="error", done=True)

async def clear_status(self):
if not self._emitted_status:
return
self._emitted_status = False
await self._emit(
"status",
{
"status": "complete",
"description": "",
"done": True,
},
twice=True,
)

async def message(self, content):
await self._emit(
"message",
{
"content": content,
},
twice=False,
)

async def citation(self, document, metadata, source):
Expand All @@ -1011,16 +1032,51 @@ async def citation(self, document, metadata, source):
"metadata": metadata,
"source": source,
},
twice=False,
)

async def code_execution_result(self, output):
async def code_execution(self, code_execution_tracker):
await self._emit(
"code_execution_result",
"citation", code_execution_tracker._citation_data(), twice=True
)


class CodeExecutionTracker:
def __init__(self, name, code, language):
self._uuid = str(uuid.uuid4())
self.name = name
self.code = code
self.language = language
self._result = {}

def set_error(self, error):
self._result["error"] = error

def set_output(self, output):
self._result["output"] = output

def add_file(self, name, url):
if "files" not in self._result:
self._result["files"] = []
self._result["files"].append(
{
"output": output,
},
"name": name,
"url": url,
}
)

def _citation_data(self):
data = {
"type": "code_execution",
"uuid": self._uuid,
"name": self.name,
"code": self.code,
"language": self.language,
}
if "output" in self._result or "error" in self._result:
data["result"] = self._result
return data


class Sandbox:
"""
Expand Down Expand Up @@ -3255,6 +3311,7 @@ def get_newer_version(cls) -> typing.Optional[str]:


UpdateCheck.init_from_frontmatter(os.path.abspath(__file__))
# fmt: on


_SAMPLE_BASH_INSTRUCTIONS = (
Expand Down
Loading

0 comments on commit c0937d0

Please sign in to comment.