Refactor cli module to smaller files (#778)
amritghimire authored Jan 4, 2025
1 parent 8dfa4ff commit f79173a
Showing 18 changed files with 1,673 additions and 1,538 deletions.
1,518 changes: 0 additions & 1,518 deletions src/datachain/cli.py

This file was deleted.

311 changes: 311 additions & 0 deletions src/datachain/cli/__init__.py
@@ -0,0 +1,311 @@
import logging
import os
import sys
import traceback
from multiprocessing import freeze_support
from typing import Optional

from datachain.cli.utils import get_logging_level
from datachain.telemetry import telemetry

from .commands import (
    clear_cache,
    completion,
    dataset_stats,
    du,
    edit_dataset,
    garbage_collect,
    index,
    list_datasets,
    ls,
    query,
    rm_dataset,
    show,
)
from .parser import get_parser

logger = logging.getLogger("datachain")


def main(argv: Optional[list[str]] = None) -> int:
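    """Parse CLI arguments, configure logging, dispatch the command, and report telemetry."""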
    from datachain.catalog import get_catalog

    # Required for Windows multiprocessing support
    freeze_support()

    datachain_parser = get_parser()
    args = datachain_parser.parse_args(argv)

    if args.command in ("internal-run-udf", "internal-run-udf-worker"):
        return handle_udf(args.command)

    logger.addHandler(logging.StreamHandler())
    logging_level = get_logging_level(args)
    logger.setLevel(logging_level)

    client_config = {
        "aws_endpoint_url": args.aws_endpoint_url,
        "anon": args.anon,
    }

    if args.debug_sql:
        # Also propagates the environment variable to any subprocesses
        os.environ["DEBUG_SHOW_SQL_QUERIES"] = "True"

    error = None

    try:
        catalog = get_catalog(client_config=client_config)
        return handle_command(args, catalog, client_config)
    except BrokenPipeError as exc:
        error, return_code = handle_broken_pipe_error(exc)
        return return_code
    except (KeyboardInterrupt, Exception) as exc:
        error, return_code = handle_general_exception(exc, args, logging_level)
        return return_code
    finally:
        telemetry.send_cli_call(args.command, error=error)


def handle_command(args, catalog, client_config) -> int:
    """Handle the different CLI commands."""
    from datachain.studio import process_jobs_args, process_studio_cli_args

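    # Map each command to a zero-argument callable so that nothing runs
    # until the parsed command has been matched.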
    command_handlers = {
        "cp": lambda: handle_cp_command(args, catalog),
        "clone": lambda: handle_clone_command(args, catalog),
        "dataset": lambda: handle_dataset_command(args, catalog),
        "ds": lambda: handle_dataset_command(args, catalog),
        "ls": lambda: handle_ls_command(args, client_config),
        "show": lambda: handle_show_command(args, catalog),
        "du": lambda: handle_du_command(args, catalog, client_config),
        "find": lambda: handle_find_command(args, catalog),
        "index": lambda: handle_index_command(args, catalog),
        "completion": lambda: handle_completion_command(args),
        "query": lambda: handle_query_command(args, catalog),
        "clear-cache": lambda: clear_cache(catalog),
        "gc": lambda: garbage_collect(catalog),
        "studio": lambda: process_studio_cli_args(args),
        "job": lambda: process_jobs_args(args),
    }

    handler = command_handlers.get(args.command)
    if handler:
        handler()
        return 0
    print(f"invalid command: {args.command}", file=sys.stderr)
    return 1


def handle_cp_command(args, catalog):
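    """Copy sources through the catalog; .edatachain file output is disabled."""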
    catalog.cp(
        args.sources,
        args.output,
        force=bool(args.force),
        update=bool(args.update),
        recursive=bool(args.recursive),
        edatachain_file=None,
        edatachain_only=False,
        no_edatachain_file=True,
        no_glob=args.no_glob,
    )


def handle_clone_command(args, catalog):
    catalog.clone(
        args.sources,
        args.output,
        force=bool(args.force),
        update=bool(args.update),
        recursive=bool(args.recursive),
        no_glob=args.no_glob,
        no_cp=args.no_cp,
        edatachain=args.edatachain,
        edatachain_file=args.edatachain_file,
    )


def handle_dataset_command(args, catalog):
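    """Dispatch a dataset subcommand ("pull", "edit", "ls", "rm"/"remove", "stats") to its handler."""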
    dataset_commands = {
        "pull": lambda: catalog.pull_dataset(
            args.dataset,
            args.output,
            local_ds_name=args.local_name,
            local_ds_version=args.local_version,
            cp=args.cp,
            force=bool(args.force),
            edatachain=args.edatachain,
            edatachain_file=args.edatachain_file,
        ),
        "edit": lambda: edit_dataset(
            catalog,
            args.name,
            new_name=args.new_name,
            description=args.description,
            labels=args.labels,
            studio=args.studio,
            local=args.local,
            all=args.all,
            team=args.team,
        ),
        "ls": lambda: list_datasets(
            catalog=catalog,
            studio=args.studio,
            local=args.local,
            all=args.all,
            team=args.team,
        ),
        "rm": lambda: rm_dataset(
            catalog,
            args.name,
            version=args.version,
            force=args.force,
            studio=args.studio,
            local=args.local,
            all=args.all,
            team=args.team,
        ),
        "remove": lambda: rm_dataset(
            catalog,
            args.name,
            version=args.version,
            force=args.force,
            studio=args.studio,
            local=args.local,
            all=args.all,
            team=args.team,
        ),
        "stats": lambda: dataset_stats(
            catalog,
            args.name,
            args.version,
            show_bytes=args.bytes,
            si=args.si,
        ),
    }

    handler = dataset_commands.get(args.datasets_cmd)
    if handler:
        return handler()
    raise Exception(f"Unexpected command {args.datasets_cmd}")


def handle_ls_command(args, client_config):
    ls(
        args.sources,
        long=bool(args.long),
        studio=args.studio,
        local=args.local,
        all=args.all,
        team=args.team,
        update=bool(args.update),
        client_config=client_config,
    )


def handle_show_command(args, catalog):
    show(
        catalog,
        args.name,
        args.version,
        limit=args.limit,
        offset=args.offset,
        columns=args.columns,
        no_collapse=args.no_collapse,
        schema=args.schema,
    )


def handle_du_command(args, catalog, client_config):
    du(
        catalog,
        args.sources,
        show_bytes=args.bytes,
        depth=args.depth,
        si=args.si,
        update=bool(args.update),
        client_config=client_config,
    )


def handle_find_command(args, catalog):
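    """Print each match as it is found; report "No results" when nothing matches."""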
    results_found = False
    for result in catalog.find(
        args.sources,
        update=bool(args.update),
        names=args.name,
        inames=args.iname,
        paths=args.path,
        ipaths=args.ipath,
        size=args.size,
        typ=args.type,
        columns=args.columns,
    ):
        print(result)
        results_found = True
    if not results_found:
        print("No results")


def handle_index_command(args, catalog):
    index(
        catalog,
        args.sources,
        update=bool(args.update),
    )


def handle_completion_command(args):
    print(completion(args.shell))


def handle_query_command(args, catalog):
    query(
        catalog,
        args.script,
        parallel=args.parallel,
        params=args.param,
    )


def handle_broken_pipe_error(exc):
    # Python flushes standard streams on exit; redirect remaining output
    # to devnull to avoid another BrokenPipeError at shutdown
    # See: https://docs.python.org/3/library/signal.html#note-on-sigpipe
    error = str(exc)
    devnull = os.open(os.devnull, os.O_WRONLY)
    os.dup2(devnull, sys.stdout.fileno())
    return error, 141  # 128 + 13 (SIGPIPE)


def handle_general_exception(exc, args, logging_level):
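    """Print a user-facing error, show the traceback at DEBUG level, and optionally drop into pdb."""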
    error = str(exc)
    if isinstance(exc, KeyboardInterrupt):
        msg = "Operation cancelled by the user"
    else:
        msg = str(exc)
    print("Error:", msg, file=sys.stderr)
    if logging_level <= logging.DEBUG:
        traceback.print_exception(
            type(exc),
            exc,
            exc.__traceback__,
            file=sys.stderr,
        )
    if args.pdb:
        import pdb  # noqa: T100

        pdb.post_mortem()
    return error, 1


def handle_udf(command):
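    """Dispatch the internal UDF commands to their entry points."""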
    if command == "internal-run-udf":
        from datachain.query.dispatch import udf_entrypoint

        return udf_entrypoint()

    if command == "internal-run-udf-worker":
        from datachain.query.dispatch import udf_worker_entrypoint

        return udf_worker_entrypoint()
29 changes: 29 additions & 0 deletions src/datachain/cli/commands/__init__.py
@@ -0,0 +1,29 @@
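# Re-export the individual command implementations so they can be imported
# directly from datachain.cli.commands.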
from .datasets import (
    dataset_stats,
    edit_dataset,
    list_datasets,
    list_datasets_local,
    rm_dataset,
)
from .du import du
from .index import index
from .ls import ls
from .misc import clear_cache, completion, garbage_collect
from .query import query
from .show import show

__all__ = [
    "clear_cache",
    "completion",
    "dataset_stats",
    "du",
    "edit_dataset",
    "garbage_collect",
    "index",
    "list_datasets",
    "list_datasets_local",
    "ls",
    "query",
    "rm_dataset",
    "show",
]