diff --git a/examples/era5_rechunk.py b/examples/era5_rechunk.py index f369e5d..0527830 100644 --- a/examples/era5_rechunk.py +++ b/examples/era5_rechunk.py @@ -20,7 +20,7 @@ INPUT_PATH = flags.DEFINE_string('input_path', None, help='Input Zarr path') OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='Output Zarr path') -RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner') +RUNNER = flags.DEFINE_string('runner', None, help='beam.runners.Runner') # pylint: disable=expression-not-assigned diff --git a/setup.py b/setup.py index e87ea6e..1a167d7 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ 'apache_beam>=2.31.0', 'dask', 'immutabledict', - 'rechunker', + 'rechunker>=0.5.1', 'zarr', 'xarray', ] @@ -42,7 +42,7 @@ setuptools.setup( name='xarray-beam', - version='0.5.1', + version='0.6.0', license='Apache 2.0', author='Google LLC', author_email='noreply@google.com', diff --git a/xarray_beam/__init__.py b/xarray_beam/__init__.py index 6a9c224..51c7f76 100644 --- a/xarray_beam/__init__.py +++ b/xarray_beam/__init__.py @@ -45,4 +45,4 @@ DatasetToZarr, ) -__version__ = '0.5.1' +__version__ = '0.6.0' diff --git a/xarray_beam/_src/rechunk.py b/xarray_beam/_src/rechunk.py index ef5a710..2f625e8 100644 --- a/xarray_beam/_src/rechunk.py +++ b/xarray_beam/_src/rechunk.py @@ -16,6 +16,7 @@ import dataclasses import itertools import logging +import math import textwrap from typing import ( Any, @@ -75,17 +76,22 @@ def rechunking_plan( source_chunks: Mapping[str, int], target_chunks: Mapping[str, int], itemsize: int, + min_mem: int, max_mem: int, -) -> List[Dict[str, int]]: +) -> List[List[Dict[str, int]]]: """Make a rechunking plan.""" - plan_shapes = algorithm.rechunking_plan( + stages = algorithm.multistage_rechunking_plan( shape=tuple(dim_sizes.values()), source_chunks=tuple(source_chunks[dim] for dim in dim_sizes), target_chunks=tuple(target_chunks[dim] for dim in dim_sizes), itemsize=itemsize, + min_mem=min_mem, max_mem=max_mem, ) - return [dict(zip(dim_sizes.keys(), shapes)) for shapes in plan_shapes] + plan = [] + for stage in stages: + plan.append([dict(zip(dim_sizes.keys(), shapes)) for shapes in stage]) + return plan def _consolidate_chunks_in_var_group( @@ -511,7 +517,8 @@ def __init__( source_chunks: Mapping[str, Union[int, Tuple[int, ...]]], target_chunks: Mapping[str, Union[int, Tuple[int, ...]]], itemsize: int, - max_mem: int = 2**30, # 1 GB + min_mem: Optional[int] = None, + max_mem: int = 2 ** 30, # 1 GB ): """Initialize Rechunk(). @@ -524,6 +531,7 @@ def __init__( itemsize: approximate number of bytes per xarray.Dataset element, after indexing out by all dimensions, e.g., `4 * len(dataset)` for float32 data or roughly `dataset.nbytes / np.prod(dataset.sizes)`. + min_mem: minimum memory that a single intermediate chunk must consume. max_mem: maximum memory that a single intermediate chunk may consume. """ if source_chunks.keys() != target_chunks.keys(): @@ -531,6 +539,8 @@ def __init__( 'source_chunks and target_chunks have different keys: ' f'{source_chunks} vs {target_chunks}' ) + if min_mem is None: + min_mem = max_mem // 100 self.dim_sizes = dim_sizes self.source_chunks = normalize_chunks(source_chunks, dim_sizes) self.target_chunks = normalize_chunks(target_chunks, dim_sizes) @@ -539,27 +549,29 @@ def __init__( self.source_chunks, self.target_chunks, itemsize=itemsize, + min_mem=min_mem, max_mem=max_mem, ) - self.read_chunks, self.intermediate_chunks, self.write_chunks = plan + plan = ( + [[self.source_chunks, self.source_chunks, plan[0][0]]] + + plan + + [[plan[-1][-1], self.target_chunks, self.target_chunks]] + ) + self.stage_in, (_, *intermediates, _), self.stage_out = zip(*plan) - # TODO(shoyer): multi-stage rechunking, when supported by rechunker: - # https://github.com/pangeo-data/rechunker/pull/89 - self.stage_in = [self.source_chunks, self.read_chunks, self.write_chunks] - self.stage_out = [self.read_chunks, self.write_chunks, self.target_chunks] logging.info( 'Rechunking plan:\n' + '\n'.join( - f'{s} -> {t}' for s, t in zip(self.stage_in, self.stage_out) + f'Stage{i}: {s} -> {t}' + for i, (s, t) in enumerate(zip(self.stage_in, self.stage_out)) ) ) - min_size = itemsize * np.prod(list(self.intermediate_chunks.values())) + min_size = min( + itemsize * math.prod(chunks.values()) for chunks in intermediates + ) logging.info(f'Smallest intermediates have size {min_size:1.3e}') def expand(self, pcoll): - # TODO(shoyer): consider splitting xarray.Dataset objects into separate - # arrays for rechunking, which is more similar to what Rechunker does and - # in principle could be more efficient. for stage, (in_chunks, out_chunks) in enumerate( zip(self.stage_in, self.stage_out) ): diff --git a/xarray_beam/_src/rechunk_test.py b/xarray_beam/_src/rechunk_test.py index 475b41a..e6f85d8 100644 --- a/xarray_beam/_src/rechunk_test.py +++ b/xarray_beam/_src/rechunk_test.py @@ -63,22 +63,24 @@ def test_normalize_chunks_errors(self): def test_rechunking_plan(self): # this trivial case fits entirely into memory - plan = rechunk.rechunking_plan( + plan, = rechunk.rechunking_plan( dim_sizes={'x': 10, 'y': 20}, source_chunks={'x': 1, 'y': 20}, target_chunks={'x': 10, 'y': 1}, itemsize=1, + min_mem=0, max_mem=200, ) expected = [{'x': 10, 'y': 20}] * 3 self.assertEqual(plan, expected) # this harder case doesn't - read_chunks, _, write_chunks = rechunk.rechunking_plan( + (read_chunks, _, write_chunks), = rechunk.rechunking_plan( dim_sizes={'t': 1000, 'x': 200, 'y': 300}, source_chunks={'t': 1, 'x': 200, 'y': 300}, target_chunks={'t': 1000, 'x': 20, 'y': 20}, itemsize=8, + min_mem=0, max_mem=10_000_000, ) self.assertGreater(read_chunks['t'], 1) @@ -88,6 +90,17 @@ def test_rechunking_plan(self): self.assertGreater(read_chunks['x'], 20) self.assertGreater(read_chunks['y'], 20) + # multiple stages + stages = rechunk.rechunking_plan( + dim_sizes={'t': 1000, 'x': 200, 'y': 300}, + source_chunks={'t': 1, 'x': 200, 'y': 300}, + target_chunks={'t': 1000, 'x': 20, 'y': 20}, + itemsize=8, + min_mem=1_000_000, + max_mem=10_000_000, + ) + self.assertGreater(len(stages), 1) + def test_consolidate_and_split_chunks(self): consolidated = [ (