From 0d13de290397a6fdd699880c10cb6e6253a475f1 Mon Sep 17 00:00:00 2001 From: Etienne Perot Date: Sat, 12 Oct 2024 23:30:51 -0700 Subject: [PATCH] Add support to run multiple snippets of code in the same sandbox. This is a necessary step to implement issue #15. --- open-webui/functions/run_code.py | 300 ++++++++++++++++++---------- open-webui/tools/run_code.py | 300 ++++++++++++++++++---------- src/openwebui/functions/run_code.py | 3 +- src/openwebui/tools/run_code.py | 3 +- src/safecode/sandbox.py | 297 +++++++++++++++++---------- 5 files changed, 580 insertions(+), 323 deletions(-) diff --git a/open-webui/functions/run_code.py b/open-webui/functions/run_code.py index 1187a77..5ae3bd3 100644 --- a/open-webui/functions/run_code.py +++ b/open-webui/functions/run_code.py @@ -305,8 +305,7 @@ async def _fail(error_message, status="SANDBOX_ERROR"): sandbox = Sandbox( tmp_dir=tmp_dir, - language=language, - code=code, + snippets=((language, code),), debug=debug, networking_allowed=valves.NETWORKING_ALLOWED, max_runtime_seconds=valves.MAX_RUNTIME_SECONDS, @@ -1192,6 +1191,10 @@ class Sandbox: # Environment variable used to detect interpreter re-execution. _MARKER_ENVIRONMENT_VARIABLE = "__CODE_EXECUTION_STAGE" + # Re-execution stages. + _STAGE_SANDBOX = "SANDBOX" + _STAGE_SNIPPET = "SNIPPET" + # libc bindings. # Populated using `_libc`. 
_LIBC = None @@ -2485,9 +2488,15 @@ def main(cls): cls._SelfFile.init() if cls._MARKER_ENVIRONMENT_VARIABLE not in os.environ: return - directives = json.load(sys.stdin) try: - result = cls(**directives["settings"])._run() + directives = json.load(sys.stdin) + sandbox = cls(**directives["settings"]) + if directives["stage"] == cls._STAGE_SANDBOX: + result = sandbox._run() + elif directives["stage"] == cls._STAGE_SNIPPET: + result = sandbox._run_snippets() + else: + raise ValueError(f"Invalid stage in directives: {directives}") except Exception as e: exception_info = { "name": e.__class__.__name__, @@ -2527,8 +2536,7 @@ def main(cls): def __init__( self, tmp_dir: str, - language: str, - code: str, + snippets: list[tuple], debug: bool, networking_allowed: bool, max_runtime_seconds: int, @@ -2540,8 +2548,7 @@ def __init__( Constructor. :param tmp_dir: Temporary directory exclusive to this sandbox. Must outlive the Sandbox object. - :param language: The language of the code; must be one of `SUPPORTED_LANGUAGES`. - :param code: Arbitrary code that needs to run in the sandbox. + :param snippets: A list of 2-tuples (language, code) to run inside the sandbox. :param debug: Whether or not to enable debug-level logging for the sandbox. :param networking_allowed: Whether the code should be given access to the network. :param max_runtime_seconds: How long the code should be allowed to run, in seconds. 
@@ -2552,8 +2559,7 @@ def __init__( self._init( { "tmp_dir": tmp_dir, - "language": language, - "code": code, + "snippets": snippets, "debug": debug, "networking_allowed": networking_allowed, "max_runtime_seconds": max_runtime_seconds, @@ -2571,8 +2577,7 @@ def _init(self, settings): self._logs_path = os.path.join(self._tmp_dir, "logs") self._gotmp_dir = os.path.join(self._tmp_dir, "gotmp") self._sandbox_shared_path = os.path.join(self._tmp_dir, "sandbox") - self._language = self._settings["language"] - self._code = self._settings["code"] + self._snippets = self._settings["snippets"] self._debug = self._settings["debug"] self._networking_allowed = self._settings["networking_allowed"] self._max_runtime_seconds = self._settings["max_runtime_seconds"] @@ -2634,17 +2639,12 @@ def _setup_sandbox(self): else: raise e.__class__(f"{e}; {switcheroo_status}") - # Locate the interpreter to use. - interpreter_path = sys.executable - if self._language == self.LANGUAGE_BASH: - interpreter_path = shutil.which("bash") - if interpreter_path is None: - raise RuntimeError("Interpreter not found") + # Mount the Python interpreter. oci_config["mounts"].append( { "type": "bind", - "source": interpreter_path, - "destination": interpreter_path, + "source": sys.executable, + "destination": sys.executable, "options": ["ro"], } ) @@ -2664,7 +2664,7 @@ def _setup_sandbox(self): ) # Create read-only empty directories. - for d in self.EMPTY_READ_ONLY_DIRECTORIES + [os.path.dirname(interpreter_path)]: + for d in self.EMPTY_READ_ONLY_DIRECTORIES + [os.path.dirname(sys.executable)]: rootfs_subdir = os.path.join(rootfs_path, d.removeprefix(os.path.sep)) os.makedirs(rootfs_subdir, mode=0o755, exist_ok=True) @@ -2715,7 +2715,7 @@ def _setup_sandbox(self): } ) - # Shared sandbox directory to propagate exit code and persistent files. + # Shared sandbox directory to propagate and persistent files. 
oci_config["mounts"].append( { "type": "bind", @@ -2724,6 +2724,8 @@ def _setup_sandbox(self): "options": ["rw"], } ) + with open(os.path.join(self._sandbox_shared_path, "self.py"), "w") as self_f: + self_f.write(self._SelfFile.contents()) if self._persistent_home_dir is not None: oci_config["mounts"].append( { @@ -2744,19 +2746,10 @@ def _setup_sandbox(self): passwd_f.write("user:x:1000:1000:user:/home/user:/bin/bash\n") # Generate command line to run in the sandbox. + oci_config["process"]["env"].append(f"{self._MARKER_ENVIRONMENT_VARIABLE}=1") self._sandboxed_command = [ - shutil.which("bash"), - "-c", - "; ".join( - ( - "echo OK > /sandbox/started", - f"{interpreter_path} /dev/stdin", - 'echo "$?" > /sandbox/.pre_exit_code || exit 1', - "if [[ -d /sandbox/persistent ]]; then cp -rd --one-file-system /home/user/. /sandbox/persistent/ || exit 2; fi", - "mv /sandbox/.pre_exit_code /sandbox/exit_code || exit 3", - "exit 0", - ) - ), + sys.executable, + "/sandbox/self.py", ] # Work around issue that gVisor does not preserve correct UID mappings when running as non-root user in the sandbox. @@ -2774,6 +2767,66 @@ def _setup_sandbox(self): with open(os.path.join(self._bundle_path, "config.json"), "w") as bundle_f: json.dump(oci_config, bundle_f, indent=2, sort_keys=True) + def _process_json_wrapped_result( + self, result: subprocess.CompletedProcess + ) -> subprocess.CompletedProcess: + """ + Process a `CompletedProcess` from a wrapped invocation. + + :param result: A `CompletedProcess` with stdout captured. + :return: A synthetic `CompletedProcess` from the JSON information in stdout. + :raises Sandbox.SandboxRuntimeException: If the JSON information cannot be interpreted. + :raises Sandbox.SandboxException: For any exception that is forwarded. 
+ """ + if not result.stdout: + raise self.SandboxRuntimeException( + f"Subprocess interpreter did not produce any output (stderr: {result.stderr})" + ) + try: + output = json.loads(result.stdout) + except json.decoder.JSONDecodeError as e: + raise self.SandboxRuntimeException( + f"Subprocess interpreter produced invalid JSON (stdout: {result.stdout}): {e}" + ) + if "exception" in output: + class_name = output["exception"]["name"] + found_class = None + for ex_class in ( + self.PlatformNotSupportedException, + self.SandboxRuntimeException, + self.CodeExecutionError, + self.ExecutionTimeoutError, + self.InterruptedExecutionError, + self.GVisorNotInstalledException, + self.CorruptDownloadException, + self.EnvironmentNeedsSetupException, + self.ExecutionError, + self.SandboxException, + ): + if ex_class.__name__ == class_name: + found_class = ex_class + break + if found_class is None: + exception_str = output["exception"]["str"] + raise self.SandboxRuntimeException(f"{class_name}: {exception_str}") + raise found_class( + *output["exception"]["args"], **output["exception"]["kwargs"] + ) + if "result" not in output: + raise self.SandboxRuntimeException( + f"Invalid response from subprocess: {output}" + ) + return subprocess.CompletedProcess( + args=output["result"]["args"], + returncode=output["result"]["returncode"], + stdout=base64.b64decode(output["result"]["stdout"]).decode( + "utf-8", errors="replace" + ), + stderr=base64.b64decode(output["result"]["stderr"]).decode( + "utf-8", errors="replace" + ), + ) + def _run(self) -> subprocess.CompletedProcess: """ Spawn and wait for the sandbox. Runs in separate forked process. 
@@ -2802,6 +2855,12 @@ def _run(self) -> subprocess.CompletedProcess: ] runsc_env = os.environ.copy() runsc_env["TMPDIR"] = self._gotmp_dir + runsc_input = json.dumps( + { + "stage": self._STAGE_SNIPPET, + "settings": self._settings, + } + ) started_marker_path = os.path.join(self._sandbox_shared_path, "started") resource_monitor_cancel = self._switcheroo.monitor_cgroup_resources() try: @@ -2809,15 +2868,18 @@ def _run(self) -> subprocess.CompletedProcess: runsc_argv, env=runsc_env, preexec_fn=self._switcheroo.move_process_to_sandbox_leaf_cgroup_lambda(), - input=self._code + "\n", + input=runsc_input, text=True, capture_output=True, - timeout=self._max_runtime_seconds, + timeout=self._max_runtime_seconds + 3, check=True, ) except subprocess.TimeoutExpired as e: raise self.ExecutionTimeoutError( - code=self._code, + code="; ".join( + f"({language}, {repr(code)}))" + for language, code in self._snippets + ), returncode=126, cmd=self._sandboxed_command, output=e.stdout, @@ -2826,7 +2888,10 @@ def _run(self) -> subprocess.CompletedProcess: except subprocess.CalledProcessError as e: if os.path.isfile(started_marker_path): raise self.InterruptedExecutionError( - code=self._code, + code="; ".join( + f"({language}, {repr(code)}))" + for language, code in self._snippets + ), returncode=127, cmd=self._sandboxed_command, output=e.stdout, @@ -2858,31 +2923,94 @@ def process_log(filename, log_line): raise self.SandboxRuntimeException( "Sandbox failed to start up properly" ) - exit_code_path = os.path.join(self._sandbox_shared_path, "exit_code") - if not os.path.isfile(exit_code_path): + return self._process_json_wrapped_result(result) + finally: + if self._switcheroo is not None: + self._switcheroo.cleanup() + + def _run_snippets(self): + """ + Run all snippets in the sandbox. + This code is called from *within* the gVisor sandbox. 
+ """ + with open("/sandbox/started", "wb") as started_f: + started_f.write(b"OK\n") + deadline = time.time() + self._max_runtime_seconds + last_result = None + overall_args = [] + overall_stdout = "" + overall_stderr = "" + if len(self._snippets) == 0: + raise self.SandboxRuntimeException("No code snippets to run") + for snippet in self._snippets: + if len(snippet) != 2: + raise self.SandboxRuntimeException(f"Invalid snippet: {snippet}") + language, code = snippet + if language not in self.SUPPORTED_LANGUAGES: + raise self.SandboxRuntimeException(f"Unsupported language: {language}") + interpreter_path = None + if language == self.LANGUAGE_BASH: + interpreter_path = shutil.which("bash") + elif language == self.LANGUAGE_PYTHON: + interpreter_path = sys.executable + if interpreter_path is None: raise self.SandboxRuntimeException( - "Sandbox failed to record an exit code" + f"Cannot find interpreter for language: {language}" + ) + cmd = [interpreter_path, "/dev/stdin"] + overall_args.append(" ".join(cmd)) + snippet_timeout = deadline - time.time() + if snippet_timeout <= 0.0: + raise self.ExecutionTimeoutError( + f"Code executed the deadline of {self._max_runtime_seconds} seconds" ) - with open(exit_code_path, "r") as exit_code_f: - exit_code_str = exit_code_f.read() try: - exit_code = int(exit_code_str.strip()) - except ValueError as e: - raise self.SandboxRuntimeException( - f"Sandbox recorded non-integer exit code: {e}" + snippet_result = subprocess.run( + cmd, + input=code + "\n", + text=True, + capture_output=True, + timeout=snippet_timeout, + check=True, + ) + except subprocess.TimeoutExpired as e: + overall_stdout += e.stdout or "" + overall_stderr += e.stderr or "" + raise self.ExecutionTimeoutError( + code=code, + returncode=126, + cmd=["sh", "-c", "; ".join(overall_args)], + output=overall_stdout, + stderr=overall_stderr, ) - if exit_code != 0: + except subprocess.CalledProcessError as e: + overall_stdout += e.stdout or "" + overall_stderr += e.stderr or "" 
raise self.CodeExecutionError( - code=self._code, - returncode=exit_code, - cmd=self._sandboxed_command, - output=result.stdout, - stderr=result.stderr, + code=code, + returncode=e.returncode, + cmd=["sh", "-c", "; ".join(overall_args)], + output=overall_stdout, + stderr=overall_stderr, ) - return result - finally: - if self._switcheroo is not None: - self._switcheroo.cleanup() + else: + last_result = snippet_result + overall_stdout += snippet_result.stdout or "" + overall_stderr += snippet_result.stderr or "" + assert last_result is not None, "Logic error" + if os.path.isdir("/sandbox/persistent"): + shutil.copytree( + "/home/user", + "/sandbox/persistent", + ignore_dangling_symlinks=True, + dirs_exist_ok=True, + ) + return subprocess.CompletedProcess( + args=["sh", "-c", "; ".join(overall_args)], + returncode=0, + stdout=overall_stdout, + stderr=overall_stderr, + ) def run(self) -> subprocess.CompletedProcess: """ @@ -2900,12 +3028,17 @@ def run(self) -> subprocess.CompletedProcess: reexec_f.write(self._SelfFile.contents()) new_env = os.environ.copy() new_env[self._MARKER_ENVIRONMENT_VARIABLE] = "1" - data = json.dumps({"settings": self._settings}) + directives = json.dumps( + { + "stage": self._STAGE_SANDBOX, + "settings": self._settings, + } + ) try: result = subprocess.run( (sys.executable, reexec_path), env=new_env, - input=data, + input=directives, text=True, capture_output=True, check=True, @@ -2913,54 +3046,7 @@ def run(self) -> subprocess.CompletedProcess: except subprocess.CalledProcessError as e: raise self.SandboxRuntimeException(f"{e} (stderr: {e.stderr})") else: - if not result.stdout: - raise self.SandboxRuntimeException( - f"Subprocess interpreter did not produce any output (stderr: {result.stderr})" - ) - try: - output = json.loads(result.stdout) - except json.decoder.JSONDecodeError as e: - raise self.SandboxRuntimeException( - f"Subprocess interpreter produced invalid JSON (stdout: {result.stdout}): {e}" - ) - if "exception" in output: - 
class_name = output["exception"]["name"] - found_class = None - for ex_class in ( - self.PlatformNotSupportedException, - self.SandboxRuntimeException, - self.CodeExecutionError, - self.ExecutionTimeoutError, - self.InterruptedExecutionError, - self.GVisorNotInstalledException, - self.CorruptDownloadException, - self.EnvironmentNeedsSetupException, - self.ExecutionError, - self.SandboxException, - ): - if ex_class.__name__ == class_name: - found_class = ex_class - break - if found_class is None: - exception_str = output["exception"]["str"] - raise self.SandboxException(f"{class_name}: {exception_str}") - raise found_class( - *output["exception"]["args"], **output["exception"]["kwargs"] - ) - if "result" not in output: - raise self.SandboxException( - f"Invalid response from subprocess: {output}" - ) - return subprocess.CompletedProcess( - args=output["result"]["args"], - returncode=output["result"]["returncode"], - stdout=base64.b64decode(output["result"]["stdout"]).decode( - "utf-8", errors="replace" - ), - stderr=base64.b64decode(output["result"]["stderr"]).decode( - "utf-8", errors="replace" - ), - ) + return self._process_json_wrapped_result(result) def debug_logs(self, write_fn: typing.Callable[[str, str], typing.Any]): """ diff --git a/open-webui/tools/run_code.py b/open-webui/tools/run_code.py index 17f2270..6707529 100644 --- a/open-webui/tools/run_code.py +++ b/open-webui/tools/run_code.py @@ -257,8 +257,7 @@ async def _fail(error_message, status="SANDBOX_ERROR"): with tempfile.TemporaryDirectory(prefix="sandbox_") as tmp_dir: sandbox = Sandbox( tmp_dir=tmp_dir, - language=language, - code=code, + snippets=((language, code),), debug=debug, networking_allowed=valves.NETWORKING_ALLOWED, max_runtime_seconds=valves.MAX_RUNTIME_SECONDS, @@ -632,6 +631,10 @@ class Sandbox: # Environment variable used to detect interpreter re-execution. _MARKER_ENVIRONMENT_VARIABLE = "__CODE_EXECUTION_STAGE" + # Re-execution stages. 
+ _STAGE_SANDBOX = "SANDBOX" + _STAGE_SNIPPET = "SNIPPET" + # libc bindings. # Populated using `_libc`. _LIBC = None @@ -1925,9 +1928,15 @@ def main(cls): cls._SelfFile.init() if cls._MARKER_ENVIRONMENT_VARIABLE not in os.environ: return - directives = json.load(sys.stdin) try: - result = cls(**directives["settings"])._run() + directives = json.load(sys.stdin) + sandbox = cls(**directives["settings"]) + if directives["stage"] == cls._STAGE_SANDBOX: + result = sandbox._run() + elif directives["stage"] == cls._STAGE_SNIPPET: + result = sandbox._run_snippets() + else: + raise ValueError(f"Invalid stage in directives: {directives}") except Exception as e: exception_info = { "name": e.__class__.__name__, @@ -1967,8 +1976,7 @@ def main(cls): def __init__( self, tmp_dir: str, - language: str, - code: str, + snippets: list[tuple], debug: bool, networking_allowed: bool, max_runtime_seconds: int, @@ -1980,8 +1988,7 @@ def __init__( Constructor. :param tmp_dir: Temporary directory exclusive to this sandbox. Must outlive the Sandbox object. - :param language: The language of the code; must be one of `SUPPORTED_LANGUAGES`. - :param code: Arbitrary code that needs to run in the sandbox. + :param snippets: A list of 2-tuples (language, code) to run inside the sandbox. :param debug: Whether or not to enable debug-level logging for the sandbox. :param networking_allowed: Whether the code should be given access to the network. :param max_runtime_seconds: How long the code should be allowed to run, in seconds. 
@@ -1992,8 +1999,7 @@ def __init__( self._init( { "tmp_dir": tmp_dir, - "language": language, - "code": code, + "snippets": snippets, "debug": debug, "networking_allowed": networking_allowed, "max_runtime_seconds": max_runtime_seconds, @@ -2011,8 +2017,7 @@ def _init(self, settings): self._logs_path = os.path.join(self._tmp_dir, "logs") self._gotmp_dir = os.path.join(self._tmp_dir, "gotmp") self._sandbox_shared_path = os.path.join(self._tmp_dir, "sandbox") - self._language = self._settings["language"] - self._code = self._settings["code"] + self._snippets = self._settings["snippets"] self._debug = self._settings["debug"] self._networking_allowed = self._settings["networking_allowed"] self._max_runtime_seconds = self._settings["max_runtime_seconds"] @@ -2074,17 +2079,12 @@ def _setup_sandbox(self): else: raise e.__class__(f"{e}; {switcheroo_status}") - # Locate the interpreter to use. - interpreter_path = sys.executable - if self._language == self.LANGUAGE_BASH: - interpreter_path = shutil.which("bash") - if interpreter_path is None: - raise RuntimeError("Interpreter not found") + # Mount the Python interpreter. oci_config["mounts"].append( { "type": "bind", - "source": interpreter_path, - "destination": interpreter_path, + "source": sys.executable, + "destination": sys.executable, "options": ["ro"], } ) @@ -2104,7 +2104,7 @@ def _setup_sandbox(self): ) # Create read-only empty directories. - for d in self.EMPTY_READ_ONLY_DIRECTORIES + [os.path.dirname(interpreter_path)]: + for d in self.EMPTY_READ_ONLY_DIRECTORIES + [os.path.dirname(sys.executable)]: rootfs_subdir = os.path.join(rootfs_path, d.removeprefix(os.path.sep)) os.makedirs(rootfs_subdir, mode=0o755, exist_ok=True) @@ -2155,7 +2155,7 @@ def _setup_sandbox(self): } ) - # Shared sandbox directory to propagate exit code and persistent files. + # Shared sandbox directory to propagate and persistent files. 
oci_config["mounts"].append( { "type": "bind", @@ -2164,6 +2164,8 @@ def _setup_sandbox(self): "options": ["rw"], } ) + with open(os.path.join(self._sandbox_shared_path, "self.py"), "w") as self_f: + self_f.write(self._SelfFile.contents()) if self._persistent_home_dir is not None: oci_config["mounts"].append( { @@ -2184,19 +2186,10 @@ def _setup_sandbox(self): passwd_f.write("user:x:1000:1000:user:/home/user:/bin/bash\n") # Generate command line to run in the sandbox. + oci_config["process"]["env"].append(f"{self._MARKER_ENVIRONMENT_VARIABLE}=1") self._sandboxed_command = [ - shutil.which("bash"), - "-c", - "; ".join( - ( - "echo OK > /sandbox/started", - f"{interpreter_path} /dev/stdin", - 'echo "$?" > /sandbox/.pre_exit_code || exit 1', - "if [[ -d /sandbox/persistent ]]; then cp -rd --one-file-system /home/user/. /sandbox/persistent/ || exit 2; fi", - "mv /sandbox/.pre_exit_code /sandbox/exit_code || exit 3", - "exit 0", - ) - ), + sys.executable, + "/sandbox/self.py", ] # Work around issue that gVisor does not preserve correct UID mappings when running as non-root user in the sandbox. @@ -2214,6 +2207,66 @@ def _setup_sandbox(self): with open(os.path.join(self._bundle_path, "config.json"), "w") as bundle_f: json.dump(oci_config, bundle_f, indent=2, sort_keys=True) + def _process_json_wrapped_result( + self, result: subprocess.CompletedProcess + ) -> subprocess.CompletedProcess: + """ + Process a `CompletedProcess` from a wrapped invocation. + + :param result: A `CompletedProcess` with stdout captured. + :return: A synthetic `CompletedProcess` from the JSON information in stdout. + :raises Sandbox.SandboxRuntimeException: If the JSON information cannot be interpreted. + :raises Sandbox.SandboxException: For any exception that is forwarded. 
+ """ + if not result.stdout: + raise self.SandboxRuntimeException( + f"Subprocess interpreter did not produce any output (stderr: {result.stderr})" + ) + try: + output = json.loads(result.stdout) + except json.decoder.JSONDecodeError as e: + raise self.SandboxRuntimeException( + f"Subprocess interpreter produced invalid JSON (stdout: {result.stdout}): {e}" + ) + if "exception" in output: + class_name = output["exception"]["name"] + found_class = None + for ex_class in ( + self.PlatformNotSupportedException, + self.SandboxRuntimeException, + self.CodeExecutionError, + self.ExecutionTimeoutError, + self.InterruptedExecutionError, + self.GVisorNotInstalledException, + self.CorruptDownloadException, + self.EnvironmentNeedsSetupException, + self.ExecutionError, + self.SandboxException, + ): + if ex_class.__name__ == class_name: + found_class = ex_class + break + if found_class is None: + exception_str = output["exception"]["str"] + raise self.SandboxRuntimeException(f"{class_name}: {exception_str}") + raise found_class( + *output["exception"]["args"], **output["exception"]["kwargs"] + ) + if "result" not in output: + raise self.SandboxRuntimeException( + f"Invalid response from subprocess: {output}" + ) + return subprocess.CompletedProcess( + args=output["result"]["args"], + returncode=output["result"]["returncode"], + stdout=base64.b64decode(output["result"]["stdout"]).decode( + "utf-8", errors="replace" + ), + stderr=base64.b64decode(output["result"]["stderr"]).decode( + "utf-8", errors="replace" + ), + ) + def _run(self) -> subprocess.CompletedProcess: """ Spawn and wait for the sandbox. Runs in separate forked process. 
@@ -2242,6 +2295,12 @@ def _run(self) -> subprocess.CompletedProcess: ] runsc_env = os.environ.copy() runsc_env["TMPDIR"] = self._gotmp_dir + runsc_input = json.dumps( + { + "stage": self._STAGE_SNIPPET, + "settings": self._settings, + } + ) started_marker_path = os.path.join(self._sandbox_shared_path, "started") resource_monitor_cancel = self._switcheroo.monitor_cgroup_resources() try: @@ -2249,15 +2308,18 @@ def _run(self) -> subprocess.CompletedProcess: runsc_argv, env=runsc_env, preexec_fn=self._switcheroo.move_process_to_sandbox_leaf_cgroup_lambda(), - input=self._code + "\n", + input=runsc_input, text=True, capture_output=True, - timeout=self._max_runtime_seconds, + timeout=self._max_runtime_seconds + 3, check=True, ) except subprocess.TimeoutExpired as e: raise self.ExecutionTimeoutError( - code=self._code, + code="; ".join( + f"({language}, {repr(code)}))" + for language, code in self._snippets + ), returncode=126, cmd=self._sandboxed_command, output=e.stdout, @@ -2266,7 +2328,10 @@ def _run(self) -> subprocess.CompletedProcess: except subprocess.CalledProcessError as e: if os.path.isfile(started_marker_path): raise self.InterruptedExecutionError( - code=self._code, + code="; ".join( + f"({language}, {repr(code)}))" + for language, code in self._snippets + ), returncode=127, cmd=self._sandboxed_command, output=e.stdout, @@ -2298,31 +2363,94 @@ def process_log(filename, log_line): raise self.SandboxRuntimeException( "Sandbox failed to start up properly" ) - exit_code_path = os.path.join(self._sandbox_shared_path, "exit_code") - if not os.path.isfile(exit_code_path): + return self._process_json_wrapped_result(result) + finally: + if self._switcheroo is not None: + self._switcheroo.cleanup() + + def _run_snippets(self): + """ + Run all snippets in the sandbox. + This code is called from *within* the gVisor sandbox. 
+ """ + with open("/sandbox/started", "wb") as started_f: + started_f.write(b"OK\n") + deadline = time.time() + self._max_runtime_seconds + last_result = None + overall_args = [] + overall_stdout = "" + overall_stderr = "" + if len(self._snippets) == 0: + raise self.SandboxRuntimeException("No code snippets to run") + for snippet in self._snippets: + if len(snippet) != 2: + raise self.SandboxRuntimeException(f"Invalid snippet: {snippet}") + language, code = snippet + if language not in self.SUPPORTED_LANGUAGES: + raise self.SandboxRuntimeException(f"Unsupported language: {language}") + interpreter_path = None + if language == self.LANGUAGE_BASH: + interpreter_path = shutil.which("bash") + elif language == self.LANGUAGE_PYTHON: + interpreter_path = sys.executable + if interpreter_path is None: raise self.SandboxRuntimeException( - "Sandbox failed to record an exit code" + f"Cannot find interpreter for language: {language}" + ) + cmd = [interpreter_path, "/dev/stdin"] + overall_args.append(" ".join(cmd)) + snippet_timeout = deadline - time.time() + if snippet_timeout <= 0.0: + raise self.ExecutionTimeoutError( + f"Code executed the deadline of {self._max_runtime_seconds} seconds" ) - with open(exit_code_path, "r") as exit_code_f: - exit_code_str = exit_code_f.read() try: - exit_code = int(exit_code_str.strip()) - except ValueError as e: - raise self.SandboxRuntimeException( - f"Sandbox recorded non-integer exit code: {e}" + snippet_result = subprocess.run( + cmd, + input=code + "\n", + text=True, + capture_output=True, + timeout=snippet_timeout, + check=True, ) - if exit_code != 0: + except subprocess.TimeoutExpired as e: + overall_stdout += e.stdout or "" + overall_stderr += e.stderr or "" + raise self.ExecutionTimeoutError( + code=code, + returncode=126, + cmd=["sh", "-c", "; ".join(overall_args)], + output=overall_stdout, + stderr=overall_stderr, + ) + except subprocess.CalledProcessError as e: + overall_stdout += e.stdout or "" + overall_stderr += e.stderr or "" 
raise self.CodeExecutionError( - code=self._code, - returncode=exit_code, - cmd=self._sandboxed_command, - output=result.stdout, - stderr=result.stderr, + code=code, + returncode=e.returncode, + cmd=["sh", "-c", "; ".join(overall_args)], + output=overall_stdout, + stderr=overall_stderr, ) - return result - finally: - if self._switcheroo is not None: - self._switcheroo.cleanup() + else: + last_result = snippet_result + overall_stdout += snippet_result.stdout or "" + overall_stderr += snippet_result.stderr or "" + assert last_result is not None, "Logic error" + if os.path.isdir("/sandbox/persistent"): + shutil.copytree( + "/home/user", + "/sandbox/persistent", + ignore_dangling_symlinks=True, + dirs_exist_ok=True, + ) + return subprocess.CompletedProcess( + args=["sh", "-c", "; ".join(overall_args)], + returncode=0, + stdout=overall_stdout, + stderr=overall_stderr, + ) def run(self) -> subprocess.CompletedProcess: """ @@ -2340,12 +2468,17 @@ def run(self) -> subprocess.CompletedProcess: reexec_f.write(self._SelfFile.contents()) new_env = os.environ.copy() new_env[self._MARKER_ENVIRONMENT_VARIABLE] = "1" - data = json.dumps({"settings": self._settings}) + directives = json.dumps( + { + "stage": self._STAGE_SANDBOX, + "settings": self._settings, + } + ) try: result = subprocess.run( (sys.executable, reexec_path), env=new_env, - input=data, + input=directives, text=True, capture_output=True, check=True, @@ -2353,54 +2486,7 @@ def run(self) -> subprocess.CompletedProcess: except subprocess.CalledProcessError as e: raise self.SandboxRuntimeException(f"{e} (stderr: {e.stderr})") else: - if not result.stdout: - raise self.SandboxRuntimeException( - f"Subprocess interpreter did not produce any output (stderr: {result.stderr})" - ) - try: - output = json.loads(result.stdout) - except json.decoder.JSONDecodeError as e: - raise self.SandboxRuntimeException( - f"Subprocess interpreter produced invalid JSON (stdout: {result.stdout}): {e}" - ) - if "exception" in output: - 
class_name = output["exception"]["name"] - found_class = None - for ex_class in ( - self.PlatformNotSupportedException, - self.SandboxRuntimeException, - self.CodeExecutionError, - self.ExecutionTimeoutError, - self.InterruptedExecutionError, - self.GVisorNotInstalledException, - self.CorruptDownloadException, - self.EnvironmentNeedsSetupException, - self.ExecutionError, - self.SandboxException, - ): - if ex_class.__name__ == class_name: - found_class = ex_class - break - if found_class is None: - exception_str = output["exception"]["str"] - raise self.SandboxException(f"{class_name}: {exception_str}") - raise found_class( - *output["exception"]["args"], **output["exception"]["kwargs"] - ) - if "result" not in output: - raise self.SandboxException( - f"Invalid response from subprocess: {output}" - ) - return subprocess.CompletedProcess( - args=output["result"]["args"], - returncode=output["result"]["returncode"], - stdout=base64.b64decode(output["result"]["stdout"]).decode( - "utf-8", errors="replace" - ), - stderr=base64.b64decode(output["result"]["stderr"]).decode( - "utf-8", errors="replace" - ), - ) + return self._process_json_wrapped_result(result) def debug_logs(self, write_fn: typing.Callable[[str, str], typing.Any]): """ diff --git a/src/openwebui/functions/run_code.py b/src/openwebui/functions/run_code.py index 5053686..6668cd1 100644 --- a/src/openwebui/functions/run_code.py +++ b/src/openwebui/functions/run_code.py @@ -264,8 +264,7 @@ async def _fail(error_message, status="SANDBOX_ERROR"): sandbox = Sandbox( tmp_dir=tmp_dir, - language=language, - code=code, + snippets=((language, code),), debug=debug, networking_allowed=valves.NETWORKING_ALLOWED, max_runtime_seconds=valves.MAX_RUNTIME_SECONDS, diff --git a/src/openwebui/tools/run_code.py b/src/openwebui/tools/run_code.py index b706dc6..454ec9d 100644 --- a/src/openwebui/tools/run_code.py +++ b/src/openwebui/tools/run_code.py @@ -210,8 +210,7 @@ async def _fail(error_message, status="SANDBOX_ERROR"): 
with tempfile.TemporaryDirectory(prefix="sandbox_") as tmp_dir: sandbox = Sandbox( tmp_dir=tmp_dir, - language=language, - code=code, + snippets=((language, code),), debug=debug, networking_allowed=valves.NETWORKING_ALLOWED, max_runtime_seconds=valves.MAX_RUNTIME_SECONDS, diff --git a/src/safecode/sandbox.py b/src/safecode/sandbox.py index 493f76e..b471dc0 100644 --- a/src/safecode/sandbox.py +++ b/src/safecode/sandbox.py @@ -188,6 +188,10 @@ class Sandbox: # Environment variable used to detect interpreter re-execution. _MARKER_ENVIRONMENT_VARIABLE = "__CODE_EXECUTION_STAGE" + # Re-execution stages. + _STAGE_SANDBOX = "SANDBOX" + _STAGE_SNIPPET = "SNIPPET" + # libc bindings. # Populated using `_libc`. _LIBC = None @@ -1481,9 +1485,15 @@ def main(cls): cls._SelfFile.init() if cls._MARKER_ENVIRONMENT_VARIABLE not in os.environ: return - directives = json.load(sys.stdin) try: - result = cls(**directives["settings"])._run() + directives = json.load(sys.stdin) + sandbox = cls(**directives["settings"]) + if directives["stage"] == cls._STAGE_SANDBOX: + result = sandbox._run() + elif directives["stage"] == cls._STAGE_SNIPPET: + result = sandbox._run_snippets() + else: + raise ValueError(f"Invalid stage in directives: {directives}") except Exception as e: exception_info = { "name": e.__class__.__name__, @@ -1523,8 +1533,7 @@ def main(cls): def __init__( self, tmp_dir: str, - language: str, - code: str, + snippets: list[tuple], debug: bool, networking_allowed: bool, max_runtime_seconds: int, @@ -1536,8 +1545,7 @@ def __init__( Constructor. :param tmp_dir: Temporary directory exclusive to this sandbox. Must outlive the Sandbox object. - :param language: The language of the code; must be one of `SUPPORTED_LANGUAGES`. - :param code: Arbitrary code that needs to run in the sandbox. + :param snippets: A list of 2-tuples (language, code) to run inside the sandbox. :param debug: Whether or not to enable debug-level logging for the sandbox. 
:param networking_allowed: Whether the code should be given access to the network. :param max_runtime_seconds: How long the code should be allowed to run, in seconds. @@ -1548,8 +1556,7 @@ def __init__( self._init( { "tmp_dir": tmp_dir, - "language": language, - "code": code, + "snippets": snippets, "debug": debug, "networking_allowed": networking_allowed, "max_runtime_seconds": max_runtime_seconds, @@ -1567,8 +1574,7 @@ def _init(self, settings): self._logs_path = os.path.join(self._tmp_dir, "logs") self._gotmp_dir = os.path.join(self._tmp_dir, "gotmp") self._sandbox_shared_path = os.path.join(self._tmp_dir, "sandbox") - self._language = self._settings["language"] - self._code = self._settings["code"] + self._snippets = self._settings["snippets"] self._debug = self._settings["debug"] self._networking_allowed = self._settings["networking_allowed"] self._max_runtime_seconds = self._settings["max_runtime_seconds"] @@ -1630,17 +1636,12 @@ def _setup_sandbox(self): else: raise e.__class__(f"{e}; {switcheroo_status}") - # Locate the interpreter to use. - interpreter_path = sys.executable - if self._language == self.LANGUAGE_BASH: - interpreter_path = shutil.which("bash") - if interpreter_path is None: - raise RuntimeError("Interpreter not found") + # Mount the Python interpreter. oci_config["mounts"].append( { "type": "bind", - "source": interpreter_path, - "destination": interpreter_path, + "source": sys.executable, + "destination": sys.executable, "options": ["ro"], } ) @@ -1660,7 +1661,7 @@ def _setup_sandbox(self): ) # Create read-only empty directories. - for d in self.EMPTY_READ_ONLY_DIRECTORIES + [os.path.dirname(interpreter_path)]: + for d in self.EMPTY_READ_ONLY_DIRECTORIES + [os.path.dirname(sys.executable)]: rootfs_subdir = os.path.join(rootfs_path, d.removeprefix(os.path.sep)) os.makedirs(rootfs_subdir, mode=0o755, exist_ok=True) @@ -1711,7 +1712,7 @@ def _setup_sandbox(self): } ) - # Shared sandbox directory to propagate exit code and persistent files. 
+ # Shared sandbox directory to propagate the sandbox script, the startup marker, and persistent files. oci_config["mounts"].append( { "type": "bind", @@ -1720,6 +1721,8 @@ def _setup_sandbox(self): "options": ["rw"], } ) + with open(os.path.join(self._sandbox_shared_path, "self.py"), "w") as self_f: + self_f.write(self._SelfFile.contents()) if self._persistent_home_dir is not None: oci_config["mounts"].append( { @@ -1740,19 +1743,10 @@ def _setup_sandbox(self): passwd_f.write("user:x:1000:1000:user:/home/user:/bin/bash\n") # Generate command line to run in the sandbox. + oci_config["process"]["env"].append(f"{self._MARKER_ENVIRONMENT_VARIABLE}=1") self._sandboxed_command = [ - shutil.which("bash"), - "-c", - "; ".join( - ( - "echo OK > /sandbox/started", - f"{interpreter_path} /dev/stdin", - 'echo "$?" > /sandbox/.pre_exit_code || exit 1', - "if [[ -d /sandbox/persistent ]]; then cp -rd --one-file-system /home/user/. /sandbox/persistent/ || exit 2; fi", - "mv /sandbox/.pre_exit_code /sandbox/exit_code || exit 3", - "exit 0", - ) - ), + sys.executable, + "/sandbox/self.py", ] # Work around issue that gVisor does not preserve correct UID mappings when running as non-root user in the sandbox. @@ -1770,6 +1764,66 @@ def _setup_sandbox(self): with open(os.path.join(self._bundle_path, "config.json"), "w") as bundle_f: json.dump(oci_config, bundle_f, indent=2, sort_keys=True) + def _process_json_wrapped_result( + self, result: subprocess.CompletedProcess + ) -> subprocess.CompletedProcess: + """ + Process a `CompletedProcess` from a wrapped invocation. + + :param result: A `CompletedProcess` with stdout captured. + :return: A synthetic `CompletedProcess` from the JSON information in stdout. + :raises Sandbox.SandboxRuntimeException: If the JSON information cannot be interpreted. + :raises Sandbox.SandboxException: For any exception that is forwarded. 
+ """ + if not result.stdout: + raise self.SandboxRuntimeException( + f"Subprocess interpreter did not produce any output (stderr: {result.stderr})" + ) + try: + output = json.loads(result.stdout) + except json.decoder.JSONDecodeError as e: + raise self.SandboxRuntimeException( + f"Subprocess interpreter produced invalid JSON (stdout: {result.stdout}): {e}" + ) + if "exception" in output: + class_name = output["exception"]["name"] + found_class = None + for ex_class in ( + self.PlatformNotSupportedException, + self.SandboxRuntimeException, + self.CodeExecutionError, + self.ExecutionTimeoutError, + self.InterruptedExecutionError, + self.GVisorNotInstalledException, + self.CorruptDownloadException, + self.EnvironmentNeedsSetupException, + self.ExecutionError, + self.SandboxException, + ): + if ex_class.__name__ == class_name: + found_class = ex_class + break + if found_class is None: + exception_str = output["exception"]["str"] + raise self.SandboxRuntimeException(f"{class_name}: {exception_str}") + raise found_class( + *output["exception"]["args"], **output["exception"]["kwargs"] + ) + if "result" not in output: + raise self.SandboxRuntimeException( + f"Invalid response from subprocess: {output}" + ) + return subprocess.CompletedProcess( + args=output["result"]["args"], + returncode=output["result"]["returncode"], + stdout=base64.b64decode(output["result"]["stdout"]).decode( + "utf-8", errors="replace" + ), + stderr=base64.b64decode(output["result"]["stderr"]).decode( + "utf-8", errors="replace" + ), + ) + def _run(self) -> subprocess.CompletedProcess: """ Spawn and wait for the sandbox. Runs in separate forked process. 
@@ -1798,6 +1852,12 @@ def _run(self) -> subprocess.CompletedProcess: ] runsc_env = os.environ.copy() runsc_env["TMPDIR"] = self._gotmp_dir + runsc_input = json.dumps( + { + "stage": self._STAGE_SNIPPET, + "settings": self._settings, + } + ) started_marker_path = os.path.join(self._sandbox_shared_path, "started") resource_monitor_cancel = self._switcheroo.monitor_cgroup_resources() try: @@ -1805,15 +1865,18 @@ def _run(self) -> subprocess.CompletedProcess: runsc_argv, env=runsc_env, preexec_fn=self._switcheroo.move_process_to_sandbox_leaf_cgroup_lambda(), - input=self._code + "\n", + input=runsc_input, text=True, capture_output=True, - timeout=self._max_runtime_seconds, + timeout=self._max_runtime_seconds + 3, check=True, ) except subprocess.TimeoutExpired as e: raise self.ExecutionTimeoutError( - code=self._code, + code="; ".join( + f"({language}, {repr(code)}))" + for language, code in self._snippets + ), returncode=126, cmd=self._sandboxed_command, output=e.stdout, @@ -1822,7 +1885,10 @@ def _run(self) -> subprocess.CompletedProcess: except subprocess.CalledProcessError as e: if os.path.isfile(started_marker_path): raise self.InterruptedExecutionError( - code=self._code, + code="; ".join( + f"({language}, {repr(code)}))" + for language, code in self._snippets + ), returncode=127, cmd=self._sandboxed_command, output=e.stdout, @@ -1854,31 +1920,94 @@ def process_log(filename, log_line): raise self.SandboxRuntimeException( "Sandbox failed to start up properly" ) - exit_code_path = os.path.join(self._sandbox_shared_path, "exit_code") - if not os.path.isfile(exit_code_path): + return self._process_json_wrapped_result(result) + finally: + if self._switcheroo is not None: + self._switcheroo.cleanup() + + def _run_snippets(self): + """ + Run all snippets in the sandbox. + This code is called from *within* the gVisor sandbox. 
+ """ + with open("/sandbox/started", "wb") as started_f: + started_f.write(b"OK\n") + deadline = time.time() + self._max_runtime_seconds + last_result = None + overall_args = [] + overall_stdout = "" + overall_stderr = "" + if len(self._snippets) == 0: + raise self.SandboxRuntimeException("No code snippets to run") + for snippet in self._snippets: + if len(snippet) != 2: + raise self.SandboxRuntimeException(f"Invalid snippet: {snippet}") + language, code = snippet + if language not in self.SUPPORTED_LANGUAGES: + raise self.SandboxRuntimeException(f"Unsupported language: {language}") + interpreter_path = None + if language == self.LANGUAGE_BASH: + interpreter_path = shutil.which("bash") + elif language == self.LANGUAGE_PYTHON: + interpreter_path = sys.executable + if interpreter_path is None: raise self.SandboxRuntimeException( - "Sandbox failed to record an exit code" + f"Cannot find interpreter for language: {language}" + ) + cmd = [interpreter_path, "/dev/stdin"] + overall_args.append(" ".join(cmd)) + snippet_timeout = deadline - time.time() + if snippet_timeout <= 0.0: + raise self.ExecutionTimeoutError( + f"Code executed the deadline of {self._max_runtime_seconds} seconds" ) - with open(exit_code_path, "r") as exit_code_f: - exit_code_str = exit_code_f.read() try: - exit_code = int(exit_code_str.strip()) - except ValueError as e: - raise self.SandboxRuntimeException( - f"Sandbox recorded non-integer exit code: {e}" + snippet_result = subprocess.run( + cmd, + input=code + "\n", + text=True, + capture_output=True, + timeout=snippet_timeout, + check=True, + ) + except subprocess.TimeoutExpired as e: + overall_stdout += e.stdout or "" + overall_stderr += e.stderr or "" + raise self.ExecutionTimeoutError( + code=code, + returncode=126, + cmd=["sh", "-c", "; ".join(overall_args)], + output=overall_stdout, + stderr=overall_stderr, ) - if exit_code != 0: + except subprocess.CalledProcessError as e: + overall_stdout += e.stdout or "" + overall_stderr += e.stderr or "" 
raise self.CodeExecutionError( - code=self._code, - returncode=exit_code, - cmd=self._sandboxed_command, - output=result.stdout, - stderr=result.stderr, + code=code, + returncode=e.returncode, + cmd=["sh", "-c", "; ".join(overall_args)], + output=overall_stdout, + stderr=overall_stderr, ) - return result - finally: - if self._switcheroo is not None: - self._switcheroo.cleanup() + else: + last_result = snippet_result + overall_stdout += snippet_result.stdout or "" + overall_stderr += snippet_result.stderr or "" + assert last_result is not None, "Logic error" + if os.path.isdir("/sandbox/persistent"): + shutil.copytree( + "/home/user", + "/sandbox/persistent", + ignore_dangling_symlinks=True, + dirs_exist_ok=True, + ) + return subprocess.CompletedProcess( + args=["sh", "-c", "; ".join(overall_args)], + returncode=0, + stdout=overall_stdout, + stderr=overall_stderr, + ) def run(self) -> subprocess.CompletedProcess: """ @@ -1896,12 +2025,17 @@ def run(self) -> subprocess.CompletedProcess: reexec_f.write(self._SelfFile.contents()) new_env = os.environ.copy() new_env[self._MARKER_ENVIRONMENT_VARIABLE] = "1" - data = json.dumps({"settings": self._settings}) + directives = json.dumps( + { + "stage": self._STAGE_SANDBOX, + "settings": self._settings, + } + ) try: result = subprocess.run( (sys.executable, reexec_path), env=new_env, - input=data, + input=directives, text=True, capture_output=True, check=True, @@ -1909,54 +2043,7 @@ def run(self) -> subprocess.CompletedProcess: except subprocess.CalledProcessError as e: raise self.SandboxRuntimeException(f"{e} (stderr: {e.stderr})") else: - if not result.stdout: - raise self.SandboxRuntimeException( - f"Subprocess interpreter did not produce any output (stderr: {result.stderr})" - ) - try: - output = json.loads(result.stdout) - except json.decoder.JSONDecodeError as e: - raise self.SandboxRuntimeException( - f"Subprocess interpreter produced invalid JSON (stdout: {result.stdout}): {e}" - ) - if "exception" in output: - 
class_name = output["exception"]["name"] - found_class = None - for ex_class in ( - self.PlatformNotSupportedException, - self.SandboxRuntimeException, - self.CodeExecutionError, - self.ExecutionTimeoutError, - self.InterruptedExecutionError, - self.GVisorNotInstalledException, - self.CorruptDownloadException, - self.EnvironmentNeedsSetupException, - self.ExecutionError, - self.SandboxException, - ): - if ex_class.__name__ == class_name: - found_class = ex_class - break - if found_class is None: - exception_str = output["exception"]["str"] - raise self.SandboxException(f"{class_name}: {exception_str}") - raise found_class( - *output["exception"]["args"], **output["exception"]["kwargs"] - ) - if "result" not in output: - raise self.SandboxException( - f"Invalid response from subprocess: {output}" - ) - return subprocess.CompletedProcess( - args=output["result"]["args"], - returncode=output["result"]["returncode"], - stdout=base64.b64decode(output["result"]["stdout"]).decode( - "utf-8", errors="replace" - ), - stderr=base64.b64decode(output["result"]["stderr"]).decode( - "utf-8", errors="replace" - ), - ) + return self._process_json_wrapped_result(result) def debug_logs(self, write_fn: typing.Callable[[str, str], typing.Any]): """