Commit 9380408
Use multi-stage rechunking.

As described in pangeo-data/rechunker#89, this can yield significant
performance benefits for rechunking large arrays.

PiperOrigin-RevId: 518325665
shoyer authored and Xarray-Beam authors committed Mar 21, 2023
1 parent 443aeae commit 9380408
Showing 5 changed files with 45 additions and 20 deletions.
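
For context, this is the kind of pipeline the change speeds up: a rough sketch (not part of this commit; the paths, chunk sizes, and itemsize are hypothetical) of rechunking a large Zarr store with xarray-beam, using the min_mem/max_mem knobs touched here:

    import apache_beam as beam
    import xarray
    import xarray_beam as xbeam

    # Hypothetical ERA5-like store, chunked one time step per chunk on disk.
    ds = xarray.open_zarr('gs://my-bucket/era5.zarr', chunks=None)
    source_chunks = {'time': 1, 'latitude': 721, 'longitude': 1440}
    target_chunks = {'time': 1000, 'latitude': 10, 'longitude': 10}
    template = xbeam.make_template(ds)

    with beam.Pipeline() as p:
      (
          p
          | xbeam.DatasetToChunks(ds, source_chunks)
          | xbeam.Rechunk(
              dict(ds.sizes),
              source_chunks,
              target_chunks,
              itemsize=4,  # bytes per element, e.g., float32
              max_mem=2**30,  # min_mem now defaults to max_mem // 100
          )
          | xbeam.ChunksToZarr('gs://my-bucket/era5_rechunked.zarr',
                               template, target_chunks)
      )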

examples/era5_rechunk.py (2 changes: 1 addition & 1 deletion)

@@ -20,7 +20,7 @@
 
 INPUT_PATH = flags.DEFINE_string('input_path', None, help='Input Zarr path')
 OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='Output Zarr path')
-RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')
+RUNNER = flags.DEFINE_string('runner', None, help='beam.runners.Runner')
 
 
 # pylint: disable=expression-not-assigned

setup.py (4 changes: 2 additions & 2 deletions)

@@ -20,7 +20,7 @@
     'apache_beam>=2.31.0',
     'dask',
     'immutabledict',
-    'rechunker',
+    'rechunker>=0.5.1',
     'zarr',
     'xarray',
 ]
@@ -42,7 +42,7 @@
 
 setuptools.setup(
     name='xarray-beam',
-    version='0.5.1',
+    version='0.6.0',
     license='Apache 2.0',
     author='Google LLC',
     author_email='[email protected]',
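
The new rechunker>=0.5.1 floor is what makes algorithm.multistage_rechunking_plan available (see pangeo-data/rechunker#89). As a sketch, it can be called directly with the same shapes exercised by the updated unit test below; each returned stage is a (read_chunks, intermediate_chunks, write_chunks) triple of shape tuples:

    from rechunker import algorithm

    stages = algorithm.multistage_rechunking_plan(
        shape=(1000, 200, 300),
        source_chunks=(1, 200, 300),
        target_chunks=(1000, 20, 20),
        itemsize=8,
        min_mem=1_000_000,   # a larger min_mem can force extra stages
        max_mem=10_000_000,
    )
    for read_chunks, intermediate_chunks, write_chunks in stages:
      print(read_chunks, '->', intermediate_chunks, '->', write_chunks)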

xarray_beam/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -45,4 +45,4 @@
     DatasetToZarr,
 )
 
-__version__ = '0.5.1'
+__version__ = '0.6.0'

xarray_beam/_src/rechunk.py (40 changes: 26 additions & 14 deletions)

@@ -16,6 +16,7 @@
 import dataclasses
 import itertools
 import logging
+import math
 import textwrap
 from typing import (
     Any,
@@ -75,17 +76,22 @@ def rechunking_plan(
     source_chunks: Mapping[str, int],
     target_chunks: Mapping[str, int],
     itemsize: int,
+    min_mem: int,
     max_mem: int,
-) -> List[Dict[str, int]]:
+) -> List[List[Dict[str, int]]]:
   """Make a rechunking plan."""
-  plan_shapes = algorithm.rechunking_plan(
+  stages = algorithm.multistage_rechunking_plan(
       shape=tuple(dim_sizes.values()),
       source_chunks=tuple(source_chunks[dim] for dim in dim_sizes),
       target_chunks=tuple(target_chunks[dim] for dim in dim_sizes),
       itemsize=itemsize,
+      min_mem=min_mem,
       max_mem=max_mem,
   )
-  return [dict(zip(dim_sizes.keys(), shapes)) for shapes in plan_shapes]
+  plan = []
+  for stage in stages:
+    plan.append([dict(zip(dim_sizes.keys(), shapes)) for shapes in stage])
+  return plan
 
 
 def _consolidate_chunks_in_var_group(
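
To make the new nested return type concrete, a hedged sketch of calling the updated function (arguments copied from the updated unit test below; the exact chunk sizes in the output are chosen by rechunker):

    from xarray_beam._src import rechunk

    stages = rechunk.rechunking_plan(
        dim_sizes={'t': 1000, 'x': 200, 'y': 300},
        source_chunks={'t': 1, 'x': 200, 'y': 300},
        target_chunks={'t': 1000, 'x': 20, 'y': 20},
        itemsize=8,
        min_mem=1_000_000,
        max_mem=10_000_000,
    )
    # stages is a list with one entry per rechunking stage:
    # [[read_chunks, intermediate_chunks, write_chunks],  # stage 0
    #  [read_chunks, intermediate_chunks, write_chunks],  # stage 1
    #  ...]
    # where each element is a dict keyed by dimension name,
    # e.g., {'t': 100, 'x': 200, 'y': 300}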
@@ -511,7 +517,8 @@ def __init__(
       source_chunks: Mapping[str, Union[int, Tuple[int, ...]]],
       target_chunks: Mapping[str, Union[int, Tuple[int, ...]]],
       itemsize: int,
-      max_mem: int = 2**30,  # 1 GB
+      min_mem: Optional[int] = None,
+      max_mem: int = 2 ** 30,  # 1 GB
   ):
     """Initialize Rechunk().
@@ -524,13 +531,16 @@ def __init__(
       itemsize: approximate number of bytes per xarray.Dataset element, after
         indexing out by all dimensions, e.g., `4 * len(dataset)` for float32
         data or roughly `dataset.nbytes / np.prod(dataset.sizes)`.
+      min_mem: minimum memory that a single intermediate chunk must consume.
       max_mem: maximum memory that a single intermediate chunk may consume.
     """
     if source_chunks.keys() != target_chunks.keys():
       raise ValueError(
           'source_chunks and target_chunks have different keys: '
           f'{source_chunks} vs {target_chunks}'
       )
+    if min_mem is None:
+      min_mem = max_mem // 100
     self.dim_sizes = dim_sizes
     self.source_chunks = normalize_chunks(source_chunks, dim_sizes)
     self.target_chunks = normalize_chunks(target_chunks, dim_sizes)
@@ -539,27 +549,29 @@ def __init__(
         self.source_chunks,
         self.target_chunks,
         itemsize=itemsize,
+        min_mem=min_mem,
         max_mem=max_mem,
     )
-    self.read_chunks, self.intermediate_chunks, self.write_chunks = plan
+    plan = (
+        [[self.source_chunks, self.source_chunks, plan[0][0]]]
+        + plan
+        + [[plan[-1][-1], self.target_chunks, self.target_chunks]]
+    )
+    self.stage_in, (_, *intermediates, _), self.stage_out = zip(*plan)
 
-    # TODO(shoyer): multi-stage rechunking, when supported by rechunker:
-    # https://github.com/pangeo-data/rechunker/pull/89
-    self.stage_in = [self.source_chunks, self.read_chunks, self.write_chunks]
-    self.stage_out = [self.read_chunks, self.write_chunks, self.target_chunks]
     logging.info(
         'Rechunking plan:\n'
         + '\n'.join(
-            f'{s} -> {t}' for s, t in zip(self.stage_in, self.stage_out)
+            f'Stage{i}: {s} -> {t}'
+            for i, (s, t) in enumerate(zip(self.stage_in, self.stage_out))
         )
     )
-    min_size = itemsize * np.prod(list(self.intermediate_chunks.values()))
+    min_size = min(
+        itemsize * math.prod(chunks.values()) for chunks in intermediates
+    )
     logging.info(f'Smallest intermediates have size {min_size:1.3e}')
 
   def expand(self, pcoll):
     # TODO(shoyer): consider splitting xarray.Dataset objects into separate
     # arrays for rechunking, which is more similar to what Rechunker does and
     # in principle could be more efficient.
     for stage, (in_chunks, out_chunks) in enumerate(
         zip(self.stage_in, self.stage_out)
     ):
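
The plan-padding logic in __init__ above is the subtle part of this change, so here is a self-contained illustration (chunk values invented for clarity) of how the rechunker stages get bracketed by a leading split stage and a trailing consolidation stage, and how stage_in/stage_out fall out of the transpose:

    # One rechunker stage, plus invented source/target chunks:
    source = {'x': 1, 'y': 20}
    read = {'x': 5, 'y': 20}
    intermediate = {'x': 5, 'y': 5}
    write = {'x': 10, 'y': 5}
    target = {'x': 10, 'y': 1}

    plan = [[read, intermediate, write]]  # as returned by rechunking_plan
    plan = (
        [[source, source, plan[0][0]]]      # first: split source -> read
        + plan                              # middle: the rechunker stages
        + [[plan[-1][-1], target, target]]  # last: consolidate write -> target
    )
    stage_in, (_, *intermediates, _), stage_out = zip(*plan)
    assert stage_in == (source, read, write)
    assert stage_out == (read, write, target)
    assert intermediates == [intermediate]

With one rechunker stage this reproduces the previous fixed three-step pipeline; with N stages it yields N + 2 Beam stages. The smallest-intermediate logging now scans all intermediates, and the switch from np.prod to math.prod keeps that product in plain Python integers.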

xarray_beam/_src/rechunk_test.py (17 changes: 15 additions & 2 deletions)

@@ -63,22 +63,24 @@ def test_normalize_chunks_errors(self):
 
   def test_rechunking_plan(self):
     # this trivial case fits entirely into memory
-    plan = rechunk.rechunking_plan(
+    plan, = rechunk.rechunking_plan(
         dim_sizes={'x': 10, 'y': 20},
         source_chunks={'x': 1, 'y': 20},
         target_chunks={'x': 10, 'y': 1},
         itemsize=1,
+        min_mem=0,
         max_mem=200,
     )
     expected = [{'x': 10, 'y': 20}] * 3
     self.assertEqual(plan, expected)
 
     # this harder case doesn't
-    read_chunks, _, write_chunks = rechunk.rechunking_plan(
+    (read_chunks, _, write_chunks), = rechunk.rechunking_plan(
         dim_sizes={'t': 1000, 'x': 200, 'y': 300},
         source_chunks={'t': 1, 'x': 200, 'y': 300},
         target_chunks={'t': 1000, 'x': 20, 'y': 20},
         itemsize=8,
+        min_mem=0,
         max_mem=10_000_000,
     )
     self.assertGreater(read_chunks['t'], 1)
@@ -88,6 +90,17 @@ def test_rechunking_plan(self):
     self.assertGreater(read_chunks['x'], 20)
     self.assertGreater(read_chunks['y'], 20)
 
+    # multiple stages
+    stages = rechunk.rechunking_plan(
+        dim_sizes={'t': 1000, 'x': 200, 'y': 300},
+        source_chunks={'t': 1, 'x': 200, 'y': 300},
+        target_chunks={'t': 1000, 'x': 20, 'y': 20},
+        itemsize=8,
+        min_mem=1_000_000,
+        max_mem=10_000_000,
+    )
+    self.assertGreater(len(stages), 1)
+
   def test_consolidate_and_split_chunks(self):
     consolidated = [
         (
