Skip to content

Commit

Permalink
cli: debug: finalize profiler reports on errors
Browse files Browse the repository at this point in the history
Even if a command failed, it might still be useful to look at the profiler report.
  • Loading branch information
efiop committed Dec 13, 2023
1 parent a02c5a7 commit 4392af6
Showing 1 changed file with 46 additions and 40 deletions.
86 changes: 46 additions & 40 deletions dvc/_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ def viztracer_profile(
tracer = viztracer.VizTracer(max_stack_depth=depth, log_async=log_async)

tracer.start()
yield
tracer.stop()

tracer.save(path() if callable(path) else path)
try:
yield
finally:
tracer.stop()
tracer.save(path() if callable(path) else path)


@contextmanager
Expand All @@ -45,31 +46,33 @@ def yappi_profile(
yappi.set_clock_type("wall" if wall_clock else "cpu")

yappi.start()
yield
yappi.stop()

threads = yappi.get_thread_stats()
stats = {}
if separate_threads:
for thread in threads:
ctx_id = thread.id
stats[ctx_id] = yappi.get_func_stats(ctx_id=ctx_id)
else:
stats[None] = yappi.get_func_stats()

fpath = path() if callable(path) else path
for ctx_id, st in stats.items():
if fpath:
out = f"{fpath}-{ctx_id}" if ctx_id is not None else fpath
st.save(out, type="callgrind")
try:
yield
finally:
yappi.stop()

threads = yappi.get_thread_stats()
stats = {}
if separate_threads:
for thread in threads:
ctx_id = thread.id
stats[ctx_id] = yappi.get_func_stats(ctx_id=ctx_id)
else:
if ctx_id is not None:
print(f"\nThread {ctx_id}") # noqa: T201
st.print_all()
if ctx_id is None:
threads.print_all()
stats[None] = yappi.get_func_stats()

fpath = path() if callable(path) else path
for ctx_id, st in stats.items():
if fpath:
out = f"{fpath}-{ctx_id}" if ctx_id is not None else fpath
st.save(out, type="callgrind")
else:
if ctx_id is not None:
print(f"\nThread {ctx_id}") # noqa: T201
st.print_all()
if ctx_id is None:
threads.print_all()

yappi.clear_stats()
yappi.clear_stats()


@contextmanager
Expand All @@ -85,13 +88,15 @@ def instrument(html_output=False):
profiler = Profiler()

profiler.start()
yield
profiler.stop()
try:
yield
finally:
profiler.stop()

if html_output:
profiler.open_in_browser()
return
print(profiler.output_text(unicode=True, color=True)) # noqa: T201
if html_output:
profiler.open_in_browser()
else:
print(profiler.output_text(unicode=True, color=True)) # noqa: T201


@contextmanager
Expand All @@ -102,13 +107,14 @@ def profile(dump_path: Optional[str] = None):
prof = cProfile.Profile()
prof.enable()

yield

prof.disable()
if not dump_path:
prof.print_stats(sort="cumtime")
return
prof.dump_stats(dump_path)
try:
yield
finally:
prof.disable()
if dump_path:
prof.dump_stats(dump_path)
else:
prof.print_stats(sort="cumtime")


@contextmanager
Expand Down

0 comments on commit 4392af6

Please sign in to comment.