Skip to content

Commit

Permalink
don't depend on datachain from PATH to exec processes (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Jul 23, 2024
1 parent 683661d commit 5312913
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
18 changes: 7 additions & 11 deletions src/datachain/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import shlex
import sys
import traceback
from argparse import SUPPRESS, Action, ArgumentParser, ArgumentTypeError, Namespace
from argparse import Action, ArgumentParser, ArgumentTypeError, Namespace
from collections.abc import Iterable, Iterator, Mapping, Sequence
from importlib.metadata import PackageNotFoundError, version
from itertools import chain
Expand Down Expand Up @@ -106,10 +106,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
parser = ArgumentParser(
description="DataChain: Wrangle unstructured AI data at scale", prog="datachain"
)

parser.add_argument("-V", "--version", action="version", version=__version__)
parser.add_argument("--internal-run-udf", action="store_true", help=SUPPRESS)
parser.add_argument("--internal-run-udf-worker", action="store_true", help=SUPPRESS)

parent_parser = ArgumentParser(add_help=False)
parent_parser.add_argument(
Expand Down Expand Up @@ -155,6 +152,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
metavar="command",
dest="command",
help=f"Use `{parser.prog} command --help` for command-specific help.",
required=True,
)
parse_cp = subp.add_parser(
"cp", parents=[parent_parser], description="Copy data files from the cloud"
Expand Down Expand Up @@ -556,6 +554,8 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915
"gc", parents=[parent_parser], description="Garbage collect temporary tables"
)

subp.add_parser("internal-run-udf", parents=[parent_parser])
subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
add_completion_parser(subp, [parent_parser])
return parser

Expand Down Expand Up @@ -910,27 +910,23 @@ def completion(shell: str) -> str:
)


def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0911, PLR0912, PLR0915
def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR0915
# Required for Windows multiprocessing support
freeze_support()

parser = get_parser()
args = parser.parse_args(argv)

if args.internal_run_udf:
if args.command == "internal-run-udf":
from datachain.query.dispatch import udf_entrypoint

return udf_entrypoint()

if args.internal_run_udf_worker:
if args.command == "internal-run-udf-worker":
from datachain.query.dispatch import udf_worker_entrypoint

return udf_worker_entrypoint()

if args.command is None:
parser.print_help()
return 1

from .catalog import get_catalog

logger.addHandler(logging.StreamHandler())
Expand Down
6 changes: 3 additions & 3 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
batched,
determine_processes,
filtered_cloudpickle_dumps,
get_datachain_executable,
)

from .metrics import metrics
Expand Down Expand Up @@ -507,13 +508,12 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:

# Run the UDFDispatcher in another process to avoid needing
# if __name__ == '__main__': in user scripts
datachain_exec_path = os.environ.get("DATACHAIN_EXEC_PATH", "datachain")

exec_cmd = get_datachain_executable()
envs = dict(os.environ)
envs.update({"PYTHONPATH": os.getcwd()})
process_data = filtered_cloudpickle_dumps(udf_info)
result = subprocess.run( # noqa: S603
[datachain_exec_path, "--internal-run-udf"],
[*exec_cmd, "internal-run-udf"],
input=process_data,
check=False,
env=envs,
Expand Down
6 changes: 6 additions & 0 deletions src/datachain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,9 @@ def filtered_cloudpickle_dumps(obj: Any) -> bytes:
for model_class, namespace in model_namespaces.items():
# Restore original __pydantic_parent_namespace__ locally.
model_class.__pydantic_parent_namespace__ = namespace


def get_datachain_executable() -> list[str]:
    """Return the command prefix used to launch the datachain CLI.

    The ``DATACHAIN_EXEC_PATH`` environment variable, when set to a
    non-empty value, is returned as a single-element command list.
    Otherwise the command runs the ``datachain`` module with the current
    interpreter (``sys.executable -m datachain``).

    Returns:
        A list of argv tokens to which callers append CLI arguments.
    """
    override = os.environ.get("DATACHAIN_EXEC_PATH")
    # An unset variable and an empty string both fall back to the default.
    if not override:
        return [sys.executable, "-m", "datachain"]
    return [override]

0 comments on commit 5312913

Please sign in to comment.