[POC][DNM] Experimental to/from_worker_storage API #1299

Draft · wants to merge 66 commits into base: branch-24.02
Changes from all commits

Commits (66):
b04fd04  Merge pull request #158 from rapidsai/branch-0.10 (raydouglass, Oct 17, 2019)
5b78860  REL v0.10.0 release (GPUtester, Oct 17, 2019)
37f1934  Add ucx-py dependency to CI (#212) (raydouglass, Dec 6, 2019)
0c46eff  Merge pull request #206 from rapidsai/branch-0.11 (raydouglass, Dec 11, 2019)
f2bbfa4  REL v0.11.0 release (GPUtester, Dec 11, 2019)
c162a22  Merge pull request #230 from rapidsai/branch-0.12 (raydouglass, Feb 4, 2020)
9b76d88  REL v0.12.0 release (GPUtester, Feb 4, 2020)
49b8f2f  Merge pull request #267 from rapidsai/branch-0.13 (raydouglass, Mar 31, 2020)
7f94db5  REL v0.13.0 release (GPUtester, Mar 31, 2020)
9db4453  Merge pull request #298 from rapidsai/branch-0.14 (raydouglass, Jun 3, 2020)
d059ffc  REL v0.14.0 release (GPUtester, Jun 3, 2020)
b234fa5  Only create Security object if TLS files are specified (pentschev, Jun 19, 2020)
a860a1b  Fix argument tls_key argument name (pentschev, Jun 19, 2020)
43de9b2  Merge pull request #321 from raydouglass/tls-fix-backport (raydouglass, Jun 22, 2020)
3fc6db4  REL v0.14.1 release (GPUtester, Jun 22, 2020)
a488e5e  Merge pull request #390 from rapidsai/branch-0.15 (raydouglass, Aug 26, 2020)
0275957  REL v0.15.0 release (GPUtester, Aug 26, 2020)
4ec39de  Merge pull request #418 from rapidsai/branch-0.16 (raydouglass, Oct 21, 2020)
d714829  REL v0.16.0 release (GPUtester, Oct 21, 2020)
07e2543  Merge pull request #469 from rapidsai/branch-0.17 (ajschmidt8, Dec 10, 2020)
81371d5  REL v0.17.0 release (GPUtester, Dec 10, 2020)
603b58d  Merge pull request #535 from rapidsai/branch-0.18 (raydouglass, Feb 24, 2021)
89b82cf  REL v0.18.0 release (GPUtester, Feb 24, 2021)
97f5193  Merge pull request #586 from rapidsai/branch-0.19 (raydouglass, Apr 21, 2021)
1acf55e  REL v0.19.0 release (GPUtester, Apr 21, 2021)
597f54e  Merge pull request #640 from rapidsai/branch-21.06 (raydouglass, Jun 9, 2021)
e5bf324  REL v21.06.00 release (GPUtester, Jun 9, 2021)
ad40dab  Merge pull request #652 from rapidsai/branch-21.06 (ajschmidt8, Jun 10, 2021)
adb7fb2  Merge pull request #696 from rapidsai/branch-21.08 (raydouglass, Aug 4, 2021)
1287a15  REL v21.08.00 release (GPUtester, Aug 4, 2021)
0bcf9dc  Merge pull request #749 from rapidsai/branch-21.10 (ajschmidt8, Oct 6, 2021)
5311c1a  REL v21.10.00 release (GPUtester, Oct 6, 2021)
0ad2f74  Merge pull request #811 from rapidsai/branch-21.12 (raydouglass, Dec 8, 2021)
e1e49b6  REL v21.12.00 release (GPUtester, Dec 8, 2021)
6070be5  Merge pull request #839 from rapidsai/branch-22.02 (raydouglass, Feb 2, 2022)
a666e9b  REL v22.02.00 release (GPUtester, Feb 2, 2022)
29a8e29  Merge pull request #888 from rapidsai/branch-22.04 (raydouglass, Apr 6, 2022)
451b3b3  REL v22.04.00 release (GPUtester, Apr 6, 2022)
a6b298d  Merge pull request #926 from rapidsai/branch-22.06 (raydouglass, Jun 7, 2022)
2992966  REL v22.06.00 release (GPUtester, Jun 7, 2022)
9860cad  update changelog (raydouglass, Aug 17, 2022)
dab48ca  Merge pull request #969 from rapidsai/branch-22.08 (raydouglass, Aug 17, 2022)
9a61ce5  REL v22.08.00 release (GPUtester, Aug 17, 2022)
62a1ee8  update changelog (raydouglass, Oct 12, 2022)
d7c6750  Merge pull request #1008 from rapidsai/branch-22.10 (raydouglass, Oct 12, 2022)
382e519  REL v22.10.00 release (GPUtester, Oct 12, 2022)
bc7ec70  Merge pull request #1059 from rapidsai/branch-22.12 (raydouglass, Dec 8, 2022)
dc4758e  REL v22.12.00 release (GPUtester, Dec 8, 2022)
7664dbd  Merge pull request #1109 from rapidsai/branch-23.02 (raydouglass, Feb 9, 2023)
748bccd  REL v23.02.00 release (raydouglass, Feb 9, 2023)
575cc6a  Merge pull request #1124 from rapidsai/branch-23.02 (raydouglass, Feb 22, 2023)
2c50668  REL v23.02.01 release (raydouglass, Feb 22, 2023)
a301937  Merge pull request #1158 from rapidsai/branch-23.04 (raydouglass, Apr 12, 2023)
d4d6a02  REL v23.04.00 release (raydouglass, Apr 12, 2023)
c55bb7f  REL Merge pull request #1170 from rapidsai/branch-23.04 (raydouglass, May 3, 2023)
ec3186d  Merge pull request #1188 from rapidsai/branch-23.06 (raydouglass, Jun 7, 2023)
fd3ab2d  REL v23.06.00 release (raydouglass, Jun 7, 2023)
d8d6ccc  Merge pull request #1217 from rapidsai/branch-23.08 (raydouglass, Aug 9, 2023)
efbd6ca  REL v23.08.00 release (raydouglass, Aug 9, 2023)
e52d438  Merge remote-tracking branch 'upstream/main' into shuffle-parquet (rjzamora, Aug 28, 2023)
5ca317c  add shuffle_to_parquet (rjzamora, Aug 28, 2023)
2df6f91  add pre_shuffle callback (rjzamora, Aug 28, 2023)
22697f6  add more options (experimental) (rjzamora, Aug 29, 2023)
24336fe  Merge remote-tracking branch 'upstream/branch-24.02' into shuffle-par… (rjzamora, Dec 12, 2023)
f67966f  basic to_worker_storage and from_worker_storage API (rjzamora, Dec 13, 2023)
cd037f5  Merge branch 'branch-24.02' into shuffle-parquet (rjzamora, Jan 24, 2024)

dask_cuda/explicit_comms/dataframe/shuffle.py: 173 additions & 1 deletion
@@ -1,8 +1,11 @@
from __future__ import annotations

import asyncio
import contextlib
import functools
import inspect
import threading
import uuid
from collections import defaultdict
from math import ceil
from operator import getitem
@@ -16,18 +19,55 @@
from dask.base import tokenize
from dask.dataframe.core import DataFrame, Series, _concat as dd_concat, new_dd_object
from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch
from distributed import wait
from distributed import get_worker, wait
from distributed.protocol import nested_deserialize, to_serialize
from distributed.worker import Worker

from .. import comms

try:
from tqdm import tqdm
except ImportError:
tqdm = lambda x: x

T = TypeVar("T")


Proxify = Callable[[T], T]


_WORKER_CACHE = {}
_WORKER_CACHE_LOCK = threading.RLock()


@contextlib.contextmanager
def get_worker_cache(name):
with _WORKER_CACHE_LOCK:
yield _get_worker_cache(name)


def _get_worker_cache(name):
"""Utility to get the `name` element of the cache
dictionary for the current worker. If executed
by anything other than a distributed Dask worker,
we will use the global `_WORKER_CACHE` variable.
"""
try:
worker = get_worker()
except ValueError:
# There is no dask.distributed worker.
# Assume client/worker are same process
global _WORKER_CACHE # pylint: disable=global-variable-not-assigned
if name not in _WORKER_CACHE:
_WORKER_CACHE[name] = {}
return _WORKER_CACHE[name]
if not hasattr(worker, "worker_cache"):
worker.worker_cache = {}
if name not in worker.worker_cache:
worker.worker_cache[name] = {}
return worker.worker_cache[name]


def get_proxify(worker: Worker) -> Proxify:
"""Get function to proxify objects"""
from dask_cuda.proxify_host_file import ProxifyHostFile
@@ -328,6 +368,8 @@ async def shuffle_task(
ignore_index: bool,
num_rounds: int,
batchsize: int,
parquet_dir: str | None,
final_task: bool,
) -> Dict[int, DataFrame]:
"""Explicit-comms shuffle task

@@ -371,6 +413,8 @@ async def shuffle_task(
assert stage.keys() == rank_to_inkeys[myrank]
no_comm_postprocess = get_no_comm_postprocess(stage, num_rounds, batchsize, proxify)

fns = []
append_files = True # Whether to keep files open between batches
out_part_id_to_dataframe_list: Dict[int, List[DataFrame]] = defaultdict(list)
for _ in range(num_rounds):
partitions = create_partitions(
@@ -386,6 +430,50 @@
out_part_id_to_dataframe_list,
)

if parquet_dir:
import cudf

out_part_ids = list(out_part_id_to_dataframe_list.keys())
for out_part_id in out_part_ids:
if append_files:
writers = _get_worker_cache("writers")
try:
writer = writers[out_part_id]
except KeyError:
fn = f"{parquet_dir}/part.{out_part_id}.parquet"
fns.append(fn)
writer = cudf.io.parquet.ParquetWriter(fn, index=False)
writers[out_part_id] = writer
dfs = out_part_id_to_dataframe_list.pop(out_part_id)
dfs = [df for df in dfs if len(df) > 0]
for df in dfs:
writer.write_table(df)
del dfs
else:
dfs = out_part_id_to_dataframe_list.pop(out_part_id)
id = str(uuid.uuid4())[:8]
fn = f"{parquet_dir}/part.{out_part_id}.{id}.parquet"
fns.append(fn)
dfs = [df for df in dfs if len(df) > 0]
if len(dfs) > 1:
with cudf.io.parquet.ParquetWriter(fn, index=False) as writer:
for df in dfs:
writer.write_table(df)
elif dfs:
dfs[0].to_parquet(fn, index=False)
del dfs
await asyncio.sleep(0)

if parquet_dir:
if append_files:
if final_task:
for out_part_id in list(writers.keys()):
writers.pop(out_part_id).close()
await asyncio.sleep(0)
del writers
return {i: fn for i, fn in enumerate(fns)}
return {i: fn for i, fn in enumerate(fns)}

# Finally, we concatenate the output dataframes into the final output partitions
ret = {}
while out_part_id_to_dataframe_list:
@@ -518,6 +606,7 @@ def shuffle(
ignore_index,
num_rounds,
batchsize,
True,
)
wait(list(shuffle_result.values()))

@@ -547,6 +636,89 @@ def shuffle(
return ret


def shuffle_to_parquet(
full_df: DataFrame,
column_names: List[str],
parquet_dir: str,
npartitions: Optional[int] = None,
ignore_index: bool = False,
batchsize: int = 2,
pre_shuffle: Optional[Callable] = None,
overwrite: bool = False,
) -> None:
from dask_cuda.explicit_comms.dataframe.utils import (
_clean_worker_storage,
_prepare_dir,
)

c = comms.default_comms()

# Assume we are writing to local worker storage
if overwrite:
wait(c.client.run(_clean_worker_storage, parquet_dir))
wait(c.client.run(_prepare_dir, parquet_dir))

# The ranks of the output workers
ranks = list(range(len(c.worker_addresses)))

# By default, we preserve number of partitions
if npartitions is None:
npartitions = full_df.npartitions

# Find the output partition IDs for each worker
div = npartitions // len(ranks)
rank_to_out_part_ids: Dict[int, Set[int]] = {} # rank -> set of partition id
for i, rank in enumerate(ranks):
rank_to_out_part_ids[rank] = set(range(div * i, div * (i + 1)))
for rank, i in zip(ranks, range(div * len(ranks), npartitions)):
rank_to_out_part_ids[rank].add(i)

parts_per_batch = len(ranks) * batchsize
num_rounds = ceil(full_df.npartitions / parts_per_batch)
for stage in tqdm(range(num_rounds)):
offset = parts_per_batch * stage
df = full_df.partitions[offset : offset + parts_per_batch]

# Execute pre-shuffle function on each batch
if callable(pre_shuffle):
df = pre_shuffle(df)

df = df.persist() # Make sure optimizations are applied to the existing graph
wait([df]) # Make sure all keys have been materialized on the workers
name = (
"explicit-comms-shuffle-"
f"{tokenize(df, column_names, npartitions, ignore_index)}"
)

# Stage all keys of `df` on the workers and cancel them, which makes it possible
# for the shuffle to free memory as the partitions of `df` are consumed.
# See CommsContext.stage_keys() for a description of staging.
rank_to_inkeys = c.stage_keys(name=name, keys=df.__dask_keys__())
max_num_inkeys = max(len(k) for k in rank_to_inkeys.values())
c.client.cancel(df)

# Run a shuffle task on each worker
shuffle_result = {}
for rank in ranks:
shuffle_result[rank] = c.submit(
c.worker_addresses[rank],
shuffle_task,
name,
rank_to_inkeys,
rank_to_out_part_ids,
column_names,
npartitions,
ignore_index,
1,
max_num_inkeys,
parquet_dir,
stage == (num_rounds - 1),
)
wait(list(shuffle_result.values()))

return


def _use_explicit_comms() -> bool:
"""Is explicit-comms and available?"""
if dask.config.get("explicit-comms", False):
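
For context, here is a minimal usage sketch of the shuffle_to_parquet path added above. It is not part of the diff: the cluster setup, input path, shuffle column, and output directory are placeholder assumptions. It assumes a running LocalCUDACluster, so that comms.default_comms() (called internally by shuffle_to_parquet) can set up explicit-comms on the workers.

# Hypothetical usage sketch, not part of this PR's diff.
from distributed import Client

import dask_cudf
from dask_cuda import LocalCUDACluster
from dask_cuda.explicit_comms.dataframe.shuffle import shuffle_to_parquet

if __name__ == "__main__":
    # Explicit-comms needs a distributed cluster; a local GPU cluster is assumed here.
    client = Client(LocalCUDACluster())

    ddf = dask_cudf.read_parquet("/data/input/*.parquet")  # placeholder input path

    # Shuffle on "key" and stream the shuffled partitions straight into
    # per-worker Parquet files under parquet_dir; nothing is returned.
    shuffle_to_parquet(
        ddf,
        column_names=["key"],                  # placeholder shuffle column
        parquet_dir="/raid/scratch/shuffled",  # placeholder local directory on each worker
        batchsize=2,                           # input partitions handled per worker in each outer round
        overwrite=True,                        # wipe parquet_dir on the workers first
    )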

dask_cuda/explicit_comms/dataframe/utils.py: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
import glob
import os
import pickle
from collections import defaultdict

from dask.blockwise import BlockIndex
from distributed import wait
from distributed.protocol import dask_deserialize, dask_serialize

from dask_cuda.explicit_comms import comms


class LazyLoad:
def __init__(self, path, index, **kwargs):
self.path = path
self.index = index
self.kwargs = kwargs

def pre_serialize(self):
"""Make the unloaded partition serializable"""
return self.load()

def load(self):
"""Load the partition into memory"""
import cudf

fn = glob.glob(f"{self.path}/*.{self.index}.parquet")
return cudf.read_parquet(fn, **self.kwargs)


@dask_serialize.register(LazyLoad)
def _serialize_unloaded(obj):
return None, [pickle.dumps(obj.pre_serialize())]


@dask_deserialize.register(LazyLoad)
def _deserialize_unloaded(header, frames):
return pickle.loads(frames[0])


def _prepare_dir(dirpath: str):
os.makedirs(dirpath, exist_ok=True)


def _clean_worker_storage(dirpath: str):
import shutil

if os.path.isdir(dirpath):
shutil.rmtree(dirpath)


def _write_partition(part, dirpath, index, token=None):
if token is None:
fn = f"{dirpath}/part.{index[0]}.parquet"
else:
fn = f"{dirpath}/part.{token}.{index[0]}.parquet"
part.to_parquet(fn)
return index


def _get_partition(dirpath, index):
return LazyLoad(dirpath, index)


def _get_metadata(dirpath, index):
import glob

import pyarrow.parquet as pq

import cudf

fn = glob.glob(f"{dirpath}/*.{index}.parquet")[0]
return cudf.DataFrame.from_arrow(
pq.ParquetFile(fn).schema.to_arrow_schema().empty_table()
)


def _load_partition(data):
if isinstance(data, LazyLoad):
data = data.load()
return data


def to_worker_storage(df, dirpath, shuffle_on=None, overwrite=False, **kwargs):

if shuffle_on:
from dask_cuda.explicit_comms.dataframe.shuffle import shuffle_to_parquet

if not isinstance(shuffle_on, list):
shuffle_on = [shuffle_on]
return shuffle_to_parquet(
df, shuffle_on, dirpath, overwrite=overwrite, **kwargs
)

c = comms.default_comms()
if overwrite:
wait(c.client.run(_clean_worker_storage, dirpath))
wait(c.client.run(_prepare_dir, dirpath))
df.map_partitions(
_write_partition,
dirpath,
BlockIndex((df.npartitions,)),
**kwargs,
).compute()


def from_worker_storage(dirpath):
import dask_cudf

c = comms.default_comms()

def get_indices(path):
return {int(fn.split(".")[-2]) for fn in glob.glob(path + "/*.parquet")}

worker_indices = c.client.run(get_indices, dirpath)

summary = defaultdict(list)
for worker, indices in worker_indices.items():
for index in indices:
summary[index].append(worker)

assignments = {}
futures = []
meta = None
for i, (index, workers) in enumerate(summary.items()):
assignments[index] = workers[i % len(workers)]
futures.append(
c.client.submit(_get_partition, dirpath, index, workers=[assignments[index]])
)
if meta is None:
meta = c.client.submit(_get_metadata, dirpath, index, workers=[assignments[index]])
wait(meta)
meta = meta.result()
wait(futures)

return dask_cudf.from_delayed(futures, meta=meta, verify_meta=False).map_partitions(
_load_partition,
meta=meta,
enforce_metadata=False,
)
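
Similarly, a hypothetical round trip through the new to_worker_storage / from_worker_storage API (not part of the diff; the cluster, paths, and column name are placeholders, and a LocalCUDACluster is assumed as above):

# Hypothetical round trip, not part of this PR's diff.
from distributed import Client

import dask_cudf
from dask_cuda import LocalCUDACluster
from dask_cuda.explicit_comms.dataframe.utils import (
    from_worker_storage,
    to_worker_storage,
)

if __name__ == "__main__":
    client = Client(LocalCUDACluster())

    ddf = dask_cudf.read_parquet("/data/input/*.parquet")  # placeholder input path

    # With shuffle_on set, this dispatches to shuffle_to_parquet; without it,
    # each partition is written as-is by _write_partition via map_partitions.
    to_worker_storage(ddf, "/raid/scratch/stored", shuffle_on="key", overwrite=True)

    # Partitions come back as LazyLoad placeholders and are only read with
    # cudf.read_parquet when _load_partition runs on the worker that owns them.
    restored = from_worker_storage("/raid/scratch/stored")
    print(restored.head())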