Support datasets with differently chunked variables in DatasetToChunks
There are two major internal changes:
1. Key objects from DatasetToChunks can now include different dimensions for different variables when using split_vars=True. This makes it easier to handle large datasets with many variables and different chunking per variable.
2. Inputs inside the DatasetToChunks pipeline can now be sharded across many tasks. This is important for scaling to large datasets, especially with this change, because the refactor above multiplies the number of inputs by the number of variables when split_vars=True. Without sharding, the machine launching the pipeline can hit performance problems (e.g., slow startup, running out of memory) once the number of inputs reaches the millions.

See the new integration test for a concrete use case resembling real model output.
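
As an illustrative aside (not part of this commit): a minimal sketch of how per-variable keys look with split_vars=True. The dataset, variable names, dimensions, and chunk sizes below are made up for the example.

```python
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

# A toy dataset whose variables span different dimensions, so their chunk keys
# differ once split_vars=True.
ds = xarray.Dataset({
    'temperature': (('time', 'x'), np.zeros((4, 6))),
    'elevation': (('x',), np.zeros(6)),
})

with beam.Pipeline() as p:
  _ = (
      p
      # For very large datasets, keys are generated on Beam workers once the
      # number of tasks exceeds shard_keys_threshold (default 200_000).
      | xbeam.DatasetToChunks(ds, chunks={'time': 2, 'x': 3}, split_vars=True)
      # Each element is a (Key, xarray.Dataset) pair. With split_vars=True, a
      # key only carries offsets for the dimensions of its own variable, e.g.:
      #   Key(offsets={'time': 2, 'x': 3}, vars={'temperature'})
      #   Key(offsets={'x': 3}, vars={'elevation'})
      | beam.MapTuple(lambda key, chunk: print(key, dict(chunk.sizes)))
  )
```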

Also revise the warning message in the README to be a bit friendlier.

Fixes #43

PiperOrigin-RevId: 471948735
shoyer authored and Xarray-Beam authors committed Sep 3, 2022
1 parent 4dced4e commit 59c74cc
Showing 9 changed files with 292 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yml
@@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- name: Cancel previous
uses: styfle/[email protected]
21 changes: 11 additions & 10 deletions README.md
@@ -20,20 +20,21 @@ multi-dimensional labeled arrays, such as:
For more about our approach and how to get started,
**[read the documentation](https://xarray-beam.readthedocs.io/)**!

**🚨 Warning: Xarray-Beam is new and unpolished 🚨**
**Warning: Xarray-Beam is a sharp tool 🔪**

Expect sharp edges 🔪 and performance cliffs 🧗, particularly related to the
management of lazy data with Dask and reading/writing data with Zarr. We have
used it to efficiently process ~25 TB datasets. We _expect_ it to scale to PB
size datasets, but that's easier said than done. We welcome feedback and
contributions from early adopters, and hope to have it ready for wider audience
soon.
Xarray-Beam is relatively new, and focused on expert users:

- We use it extensively at Google for processing large-scale weather datasets,
but there is not yet a vibrant external community.
- It provides low-level abstractions that facilitate writing very large
scale data pipelines (e.g., 100+ TB), but by design it requires explicitly
thinking about how every operation is parallelized.

## Installation

Xarray-Beam requires recent versions of immutabledict, xarray, dask, rechunker
and zarr, and the *latest* release of Apache Beam (2.31.0 or later). For best
performance when writing Zarr files, use Xarray 0.19.0 or later.
Xarray-Beam requires recent versions of immutabledict, Xarray, Dask, Rechunker,
Zarr, and Apache Beam. For best performance when writing Zarr files, use Xarray
0.19.0 or later.

## Disclaimer

4 changes: 2 additions & 2 deletions setup.py
@@ -42,7 +42,7 @@

setuptools.setup(
name='xarray-beam',
version='0.3.1',
version='0.4.0',
license='Apache 2.0',
author='Google LLC',
author_email='[email protected]',
@@ -52,6 +52,6 @@
'docs': docs_requires,
},
url='https://github.com/google/xarray-beam',
packages=setuptools.find_packages(exclude=["examples"]),
packages=setuptools.find_packages(exclude=['examples']),
python_requires='>=3',
)
8 changes: 4 additions & 4 deletions xarray_beam/__init__.py
@@ -14,16 +14,16 @@
"""Public API for Xarray-Beam."""

# pylint: disable=g-multiple-import
from xarray_beam._src.combiners import (
MeanCombineFn,
)
from xarray_beam._src.core import (
Key,
DatasetToChunks,
ValidateEachChunk,
offsets_to_slices,
validate_chunk
)
from xarray_beam._src.combiners import (
MeanCombineFn,
)
from xarray_beam._src.rechunk import (
ConsolidateChunks,
ConsolidateVariables,
@@ -43,4 +43,4 @@
)
from xarray_beam import Mean

__version__ = '0.3.1'
__version__ = '0.4.0'
138 changes: 117 additions & 21 deletions xarray_beam/_src/core.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Core data model for xarray-beam."""
import itertools
import math
from typing import (
AbstractSet,
Container,
@@ -196,16 +197,16 @@ def _chunks_to_offsets(


def iter_chunk_keys(
chunks: Mapping[str, Tuple[int, ...]],
offsets: Mapping[str, Sequence[int]],
vars: Optional[AbstractSet[str]] = None, # pylint: disable=redefined-builtin
) -> Iterator[Key]:
"""Iterate over the Key objects corresponding to the given chunks."""
all_offsets = _chunks_to_offsets(chunks)
chunk_indices = [range(len(sizes)) for sizes in chunks.values()]
chunk_indices = [range(len(sizes)) for sizes in offsets.values()]
for indices in itertools.product(*chunk_indices):
offsets = {
dim: all_offsets[dim][index] for dim, index in zip(chunks, indices)
key_offsets = {
dim: offsets[dim][index] for dim, index in zip(offsets, indices)
}
yield Key(offsets)
yield Key(key_offsets, vars)


def compute_offset_index(
@@ -262,6 +263,7 @@ def __init__(
chunks: Optional[Mapping[str, Union[int, Tuple[int, ...]]]] = None,
split_vars: bool = False,
num_threads: Optional[int] = None,
shard_keys_threshold: int = 200_000,
):
"""Initialize DatasetToChunks.
@@ -271,44 +273,138 @@
chunked. If the dataset *is* already chunked with Dask, `chunks` takes
precedence over the existing chunks.
split_vars: whether to split the dataset into separate records for each
data variables or to keep all data variables together.
data variable or to keep all data variables together.
num_threads: optional number of Dataset chunks to load in parallel per
worker. More threads can increase throughput, but also increases memory
usage and makes it harder for Beam runners to shard work. Note that each
variable in a Dataset is already loaded in parallel, so this is most
useful for Datasets with a small number of variables.
shard_keys_threshold: threshold at which to compute keys on Beam workers,
rather than only on the host process. This is important for scaling
pipelines to millions of tasks.
"""
if chunks is None:
chunks = dataset.chunks
if chunks is None:
raise ValueError('dataset must be chunked or chunks must be set')
chunks = normalize_expanded_chunks(chunks, dataset.sizes)
raise ValueError('dataset must be chunked or chunks must be provided')
expanded_chunks = normalize_expanded_chunks(chunks, dataset.sizes)
self.dataset = dataset
self.chunks = chunks
self.expanded_chunks = expanded_chunks
self.split_vars = split_vars
self.num_threads = num_threads
self.offset_index = compute_offset_index(_chunks_to_offsets(chunks))
self.shard_keys_threshold = shard_keys_threshold
# TODO(shoyer): consider recalculating these potentially large properties on
# each worker, rather than only once on the host.
self.offsets = _chunks_to_offsets(expanded_chunks)
self.offset_index = compute_offset_index(self.offsets)
# We use the simple heuristic of only sharding inputs along the dimension
# with the most chunks.
lengths = {k: len(v) for k, v in self.offsets.items()}
self.sharded_dim = max(lengths, key=lengths.get) if lengths else None
self.shard_count = self._shard_count()

def _task_count(self) -> int:
"""Count the number of tasks emitted by this transform."""
counts = {k: len(v) for k, v in self.expanded_chunks.items()}
if not self.split_vars:
return int(np.prod(list(counts.values())))
total = 0
for variable in self.dataset.values():
count_list = [v for k, v in counts.items() if k in variable.dims]
total += int(np.prod(count_list))
return total

def _shard_count(self) -> Optional[int]:
"""Determine the number of times to shard input keys."""
task_count = self._task_count()
if task_count <= self.shard_keys_threshold:
return None # no sharding

if not self.split_vars:
return math.ceil(task_count / self.shard_keys_threshold)

var_count = sum(
self.sharded_dim in var.dims for var in self.dataset.values()
)
return math.ceil(task_count / (var_count * self.shard_keys_threshold))

def _iter_all_keys(self) -> Iterator[Key]:
"""Iterate over all Key objects."""
if not self.split_vars:
yield from iter_chunk_keys(self.offsets)
else:
for name, variable in self.dataset.items():
relevant_offsets = {
k: v for k, v in self.offsets.items() if k in variable.dims
}
yield from iter_chunk_keys(relevant_offsets, vars={name})

def _iter_shard_keys(
self, shard_id: Optional[int], var_name: Optional[str]
) -> Iterator[Key]:
"""Iterate over Key objects for a specific shard and variable."""
if var_name is None:
offsets = self.offsets
else:
offsets = {
dim: self.offsets[dim] for dim in self.dataset[var_name].dims
}

if shard_id is None:
assert self.split_vars
yield from iter_chunk_keys(offsets, vars={var_name})
else:
assert self.split_vars == (var_name is not None)
dim = self.sharded_dim
count = math.ceil(len(self.offsets[dim]) / self.shard_count)
dim_slice = slice(shard_id * count, (shard_id + 1) * count)
offsets = {**offsets, dim: offsets[dim][dim_slice]}
vars_ = {var_name} if self.split_vars else None
yield from iter_chunk_keys(offsets, vars=vars_)

def _shard_inputs(self) -> List[Tuple[Optional[int], Optional[str]]]:
"""Create inputs for sharded key iterators."""
if not self.split_vars:
return [(i, None) for i in range(self.shard_count)]

inputs = []
for name, variable in self.dataset.items():
if self.sharded_dim in variable.dims:
inputs.extend([(i, name) for i in range(self.shard_count)])
else:
inputs.append((None, name))
return inputs

def _key_to_chunks(self, key: Key) -> Iterator[Tuple[Key, xarray.Dataset]]:
"""Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
sizes = {
dim: self.chunks[dim][self.offset_index[dim][offset]]
dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
for dim, offset in key.offsets.items()
}
slices = offsets_to_slices(key.offsets, sizes)
chunk = self.dataset.isel(slices)
dataset = self.dataset if key.vars is None else self.dataset[list(key.vars)]
chunk = dataset.isel(slices)
# Load the data, using a separate thread for each variable
num_threads = len(self.dataset.data_vars)
num_threads = len(self.dataset)
result = chunk.chunk().compute(num_workers=num_threads)
if self.split_vars:
for k in result:
yield key.replace(vars={k}), result[[k]]
else:
yield key, result
yield key, result

def expand(self, pcoll):
if self.shard_count is None:
# Create all keys on the machine launching the Beam pipeline. This is
# faster if the number of keys is small.
key_pcoll = pcoll | beam.Create(self._iter_all_keys())
else:
# Create keys in separate shards on Beam workers. This is more scalable.
key_pcoll = (
pcoll
| beam.Create(self._shard_inputs())
| beam.FlatMapTuple(self._iter_shard_keys)
| beam.Reshuffle()
)

return (
pcoll
| beam.Create(iter_chunk_keys(self.chunks))
key_pcoll
| threadmap.FlatThreadMap(
self._key_to_chunks, num_threads=self.num_threads
)
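
As a second illustrative aside, separate from the diff above: the rough arithmetic behind the new input sharding, using hypothetical chunk counts. The 200_000 default comes from the new shard_keys_threshold argument; everything else here is invented for illustration.

```python
import math

# Hypothetical chunk counts per dimension; in DatasetToChunks these come from
# normalize_expanded_chunks and _chunks_to_offsets.
chunks_per_dim = {'time': 10_000, 'x': 100}
shard_keys_threshold = 200_000  # default value of the new argument

# Without split_vars, _task_count is the product of per-dimension chunk counts.
task_count = math.prod(chunks_per_dim.values())  # 1_000_000 tasks

# Keys are sharded along the dimension with the most chunks ('time' here), and
# _shard_count picks enough shards to keep each one under the threshold.
shard_count = math.ceil(task_count / shard_keys_threshold)  # 5 shards
print(task_count, shard_count)
```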
