From eee5e2a3cf128d0cb971f5ecd239ca2597da1563 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 7 Nov 2024 10:01:10 -0700 Subject: [PATCH 1/5] Use shuffle in groupby binary ops. xref #9546 Closes #9267 --- xarray/core/groupby.py | 83 +++++++++++++++++++++++++++++------- xarray/tests/test_groupby.py | 4 +- 2 files changed, 70 insertions(+), 17 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ceae79031f8..810a056e4f7 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -20,6 +20,7 @@ from xarray.core.alignment import align, broadcast from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce +from xarray.core.computation import apply_ufunc from xarray.core.concat import concat from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.duck_array_ops import where @@ -49,7 +50,7 @@ peek_at, ) from xarray.core.variable import IndexVariable, Variable -from xarray.namedarray.pycompat import is_chunked_array +from xarray.namedarray.pycompat import is_chunked_array, is_duck_dask_array if TYPE_CHECKING: from numpy.typing import ArrayLike @@ -899,25 +900,75 @@ def _binary_op(self, other, f, reflexive=False): group = group.where(~mask, drop=True) codes = codes.where(~mask, drop=True).astype(int) - # if other is dask-backed, that's a hint that the - # "expanded" dataset is too big to hold in memory. - # this can be the case when `other` was read from disk - # and contains our lazy indexing classes - # We need to check for dask-backed Datasets - # so utils.is_duck_dask_array does not work for this check - if obj.chunks and not other.chunks: - # TODO: What about datasets with some dask vars, and others not? 
- # This handles dims other than `name`` - chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims} - # a chunk size of 1 seems reasonable since we expect individual elements of - # other to be repeated multiple times across the reduced dimension(s) - chunks[name] = 1 - other = other.chunk(chunks) + def _vindex_wrapper(array, idxr, like): + # we want to use the fact that we know the chunksizes for the output (matches obj) + # so we can't just use Variable's indexing + import dask + from dask.array.core import slices_from_chunks + from dask.graph_manipulation import clone + + array = clone(array) # FIXME: add to dask + + assert array.ndim == 1 + to_shape = like.shape[-1:] + to_chunks = like.chunks[-1:] + flat_indices = [ + idxr[slicer].ravel().tolist() + for slicer in slices_from_chunks(to_chunks) + ] + # FIXME: figure out axis + shuffled = dask.array.shuffle( + array, flat_indices, axis=array.ndim - 1, chunks="auto" + ) + if shuffled.shape != to_shape: + return dask.array.reshape_blockwise( + shuffled, shape=to_shape, chunks=to_chunks + ) + else: + return shuffled # codes are defined for coord, so we align `other` with `coord` # before indexing other, _ = align(other, coord, join="right", copy=False) - expanded = other.isel({name: codes}) + + other_as_dataset = ( + other._to_temp_dataset() if isinstance(other, DataArray) else other + ) + obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj + dask_vars = [] + non_dask_vars = [] + for varname, var in other_as_dataset._variables.items(): + if is_duck_dask_array(var._data): + dask_vars.append(varname) + else: + non_dask_vars.append(varname) + expanded = other_as_dataset[non_dask_vars].isel({name: codes}) + if dask_vars: + other_dims = other_as_dataset[dask_vars].dims + obj_dims = obj_as_dataset[dask_vars].dims + expanded = expanded.merge( + apply_ufunc( + _vindex_wrapper, + other_as_dataset[dask_vars], + codes, + obj_as_dataset[dask_vars], + input_core_dims=[ + tuple(other_dims), 
# FIXME: ..., name + tuple(codes.dims), + tuple(obj_dims), + ], + # When other is the result of a reduction over Ellipsis + # obj.dims is a superset of other.dims, and contains + # dims not present in the output + exclude_dims=set(obj_dims) - set(other_dims), + output_core_dims=[tuple(codes.dims)], + dask="allowed", + join=OPTIONS["arithmetic_join"], + ) + ) + + if isinstance(other, DataArray): + expanded = other._from_temp_dataset(expanded) result = g(obj, expanded) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e4383dd58a9..272f012564b 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2654,12 +2654,14 @@ def test_groupby_math_auto_chunk() -> None: dims=("y", "x"), coords={"label": ("x", [2, 2, 1])}, ) + # da.groupby("label").min(...) sub = xr.DataArray( InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]} ) chunked = da.chunk(x=1, y=2) chunked.label.load() - actual = chunked.groupby("label") - sub + with raise_if_dask_computes(): + actual = chunked.groupby("label") - sub assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)} From 9294b9408ffe40aaec9bca24ccd9450876072613 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 7 Dec 2024 14:50:58 -0700 Subject: [PATCH 2/5] use map instead --- xarray/core/groupby.py | 66 ++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 810a056e4f7..ebbc427cdf0 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -20,7 +20,6 @@ from xarray.core.alignment import align, broadcast from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce -from xarray.core.computation import apply_ufunc from xarray.core.concat import concat from xarray.core.coordinates import Coordinates, _coordinates_from_variable from 
xarray.core.duck_array_ops import where @@ -900,9 +899,20 @@ def _binary_op(self, other, f, reflexive=False): group = group.where(~mask, drop=True) codes = codes.where(~mask, drop=True).astype(int) - def _vindex_wrapper(array, idxr, like): + def _vindex_like(da: DataArray, dim, indexer: DataArray): # we want to use the fact that we know the chunksizes for the output (matches obj) # so we can't just use Variable's indexing + + array = da._variable._data + like_da = obj_as_dataset.get(da.name) + if not is_duck_dask_array(array): + if like_da is None or not is_duck_dask_array(like_da._variable._data): + return da.isel({dim: indexer}) + else: + da = da.chunk("auto") + like = like_da._variable._data + array = da._variable._data + import dask from dask.array.core import slices_from_chunks from dask.graph_manipulation import clone @@ -910,22 +920,27 @@ def _vindex_wrapper(array, idxr, like): array = clone(array) # FIXME: add to dask assert array.ndim == 1 - to_shape = like.shape[-1:] - to_chunks = like.chunks[-1:] + dims = indexer.dims + axes = tuple(like_da.get_axis_num(dim) for dim in dims) + to_shape = tuple(size for ax, size in enumerate(like.shape) if ax in axes) + to_chunks = tuple( + chunksize for ax, chunksize in enumerate(like.chunks) if ax in axes + ) + idxr = indexer._variable._data + + # shuffle indices that can be reshaped blockwise to desired shape flat_indices = [ idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(to_chunks) ] - # FIXME: figure out axis shuffled = dask.array.shuffle( - array, flat_indices, axis=array.ndim - 1, chunks="auto" + array, flat_indices, axis=da.get_axis_num(dim), chunks="auto" ) if shuffled.shape != to_shape: - return dask.array.reshape_blockwise( + shuffled = dask.array.reshape_blockwise( shuffled, shape=to_shape, chunks=to_chunks ) - else: - return shuffled + return DataArray(dims=like_da.dims[-1:], data=shuffled, attrs=da.attrs) # codes are defined for coord, so we align `other` with `coord` # before indexing @@ 
-935,38 +950,7 @@ def _vindex_wrapper(array, idxr, like): other._to_temp_dataset() if isinstance(other, DataArray) else other ) obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj - dask_vars = [] - non_dask_vars = [] - for varname, var in other_as_dataset._variables.items(): - if is_duck_dask_array(var._data): - dask_vars.append(varname) - else: - non_dask_vars.append(varname) - expanded = other_as_dataset[non_dask_vars].isel({name: codes}) - if dask_vars: - other_dims = other_as_dataset[dask_vars].dims - obj_dims = obj_as_dataset[dask_vars].dims - expanded = expanded.merge( - apply_ufunc( - _vindex_wrapper, - other_as_dataset[dask_vars], - codes, - obj_as_dataset[dask_vars], - input_core_dims=[ - tuple(other_dims), # FIXME: ..., name - tuple(codes.dims), - tuple(obj_dims), - ], - # When other is the result of a reduction over Ellipsis - # obj.dims is a superset of other.dims, and contains - # dims not present in the output - exclude_dims=set(obj_dims) - set(other_dims), - output_core_dims=[tuple(codes.dims)], - dask="allowed", - join=OPTIONS["arithmetic_join"], - ) - ) - + expanded = other_as_dataset.map(_vindex_like, dim=name, indexer=codes) if isinstance(other, DataArray): expanded = other._from_temp_dataset(expanded) From 4aa6bb97418c831d7f41ce188b67581840f2558a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 8 Dec 2024 21:58:20 -0700 Subject: [PATCH 3/5] cleanup --- xarray/core/groupby.py | 103 +++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 44 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ebbc427cdf0..a39c6583981 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -172,6 +172,62 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray return newpositions[newpositions != -1] +def _vindex_like( + da: DataArray, dim: Hashable, indexer: DataArray, like_ds: Dataset | None +) -> Variable: + """ + Apply a vectorized indexer, 
optionally matching the chunks of a DataArray + of the same name in `like_ds`. This is useful for GroupBy binary ops. + This function is intended to be used with Dataset.map. + """ + # we want to use the fact that we know the chunksizes for the output (matches obj) + # so we can't just use Variable's indexing directly + array = da._variable._data + like_da = like_ds.get(da.name) + if not is_duck_dask_array(array): + if like_da is None or not is_duck_dask_array(like_da._variable._data): + # TODO: we should instead check for `shuffle` and `reshape_blockwise` + return da.isel({dim: indexer}) + else: + da = da.chunk("auto") + + like = like_da._variable._data + array = da._variable._data + + import dask.array + from dask.array.core import slices_from_chunks + from dask.graph_manipulation import clone + + from xarray.core.dask_array_compat import reshape_blockwise + + array = clone(array) # FIXME: add to dask + + assert array.ndim == 1 + dims = indexer.dims + axes = tuple(like_da.get_axis_num(dim) for dim in dims) + to_shape = tuple(size for ax, size in enumerate(like.shape) if ax in axes) + to_chunks = tuple( + chunksize for ax, chunksize in enumerate(like.chunks) if ax in axes + ) + idxr = indexer._variable._data + + # dimensions for indexed result + out_dims = tuple( + itertools.chain(*(indexer.dims if this == dim else (this,) for this in da.dims)) + ) + + # shuffle indices that can be reshaped blockwise to desired shape + flat_indices = [ + idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(to_chunks) + ] + shuffled = dask.array.shuffle( + array, flat_indices, axis=da.get_axis_num(dim), chunks="auto" + ) + if shuffled.shape != to_shape: + shuffled = reshape_blockwise(shuffled, shape=to_shape, chunks=to_chunks) + return Variable(dims=out_dims, data=shuffled, attrs=da.attrs) + + class _DummyGroup(Generic[T_Xarray]): """Class for keeping track of grouped dimensions without coordinates. 
@@ -899,49 +955,6 @@ def _binary_op(self, other, f, reflexive=False): group = group.where(~mask, drop=True) codes = codes.where(~mask, drop=True).astype(int) - def _vindex_like(da: DataArray, dim, indexer: DataArray): - # we want to use the fact that we know the chunksizes for the output (matches obj) - # so we can't just use Variable's indexing - - array = da._variable._data - like_da = obj_as_dataset.get(da.name) - if not is_duck_dask_array(array): - if like_da is None or not is_duck_dask_array(like_da._variable._data): - return da.isel({dim: indexer}) - else: - da = da.chunk("auto") - like = like_da._variable._data - array = da._variable._data - - import dask - from dask.array.core import slices_from_chunks - from dask.graph_manipulation import clone - - array = clone(array) # FIXME: add to dask - - assert array.ndim == 1 - dims = indexer.dims - axes = tuple(like_da.get_axis_num(dim) for dim in dims) - to_shape = tuple(size for ax, size in enumerate(like.shape) if ax in axes) - to_chunks = tuple( - chunksize for ax, chunksize in enumerate(like.chunks) if ax in axes - ) - idxr = indexer._variable._data - - # shuffle indices that can be reshaped blockwise to desired shape - flat_indices = [ - idxr[slicer].ravel().tolist() - for slicer in slices_from_chunks(to_chunks) - ] - shuffled = dask.array.shuffle( - array, flat_indices, axis=da.get_axis_num(dim), chunks="auto" - ) - if shuffled.shape != to_shape: - shuffled = dask.array.reshape_blockwise( - shuffled, shape=to_shape, chunks=to_chunks - ) - return DataArray(dims=like_da.dims[-1:], data=shuffled, attrs=da.attrs) - # codes are defined for coord, so we align `other` with `coord` # before indexing other, _ = align(other, coord, join="right", copy=False) @@ -950,7 +963,9 @@ def _vindex_like(da: DataArray, dim, indexer: DataArray): other._to_temp_dataset() if isinstance(other, DataArray) else other ) obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj - expanded = 
other_as_dataset.map(_vindex_like, dim=name, indexer=codes) + expanded = other_as_dataset.map( + _vindex_like, dim=name, indexer=codes, like_ds=obj_as_dataset + ) if isinstance(other, DataArray): expanded = other._from_temp_dataset(expanded) From 633b36190c6d780a2769d18a465071c0a39d6345 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 8 Dec 2024 22:31:34 -0700 Subject: [PATCH 4/5] Fix for more broadcast dims --- xarray/core/groupby.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a39c6583981..fb5535e0128 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -191,7 +191,6 @@ def _vindex_like( else: da = da.chunk("auto") - like = like_da._variable._data array = da._variable._data import dask.array @@ -201,30 +200,37 @@ def _vindex_like( from xarray.core.dask_array_compat import reshape_blockwise array = clone(array) # FIXME: add to dask - - assert array.ndim == 1 - dims = indexer.dims - axes = tuple(like_da.get_axis_num(dim) for dim in dims) - to_shape = tuple(size for ax, size in enumerate(like.shape) if ax in axes) - to_chunks = tuple( - chunksize for ax, chunksize in enumerate(like.chunks) if ax in axes - ) - idxr = indexer._variable._data + # array = clone(array) # FIXME: add to dask # dimensions for indexed result out_dims = tuple( itertools.chain(*(indexer.dims if this == dim else (this,) for this in da.dims)) ) + out_chunks = tuple( + da.chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims + ) + out_shape = tuple(da.sizes.get(dim, like_da.sizes[dim]) for dim in out_dims) + idxr = indexer._variable._data # shuffle indices that can be reshaped blockwise to desired shape + core_dim_chunks = tuple( + chunks + for dim, chunks in zip(out_dims, out_chunks, strict=True) + if dim in indexer.dims + ) flat_indices = [ - idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(to_chunks) + idxr[slicer].ravel().tolist() for 
slicer in slices_from_chunks(core_dim_chunks) ] shuffled = dask.array.shuffle( array, flat_indices, axis=da.get_axis_num(dim), chunks="auto" ) - if shuffled.shape != to_shape: - shuffled = reshape_blockwise(shuffled, shape=to_shape, chunks=to_chunks) + # shuffle with `chunks="auto"` could change chunks, so we recalculate out_chunks + new_chunksizes = dict(zip(da.dims, shuffled.chunks, strict=True)) + out_chunks = tuple( + new_chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims + ) + if shuffled.shape != out_shape: + shuffled = reshape_blockwise(shuffled, shape=out_shape, chunks=out_chunks) return Variable(dims=out_dims, data=shuffled, attrs=da.attrs) From 18c96eaf7f4e793ac0fcb515509a4c11026d16ab Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 17 Dec 2024 21:36:01 -0700 Subject: [PATCH 5/5] Switch to Variable --- xarray/core/groupby.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index fb5535e0128..a85a946a72d 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -191,25 +191,26 @@ def _vindex_like( else: da = da.chunk("auto") - array = da._variable._data + var = da._variable + array = var._data - import dask.array from dask.array.core import slices_from_chunks from dask.graph_manipulation import clone from xarray.core.dask_array_compat import reshape_blockwise array = clone(array) # FIXME: add to dask - # array = clone(array) # FIXME: add to dask # dimensions for indexed result out_dims = tuple( - itertools.chain(*(indexer.dims if this == dim else (this,) for this in da.dims)) + itertools.chain( + *(indexer.dims if this == dim else (this,) for this in var.dims) + ) ) out_chunks = tuple( da.chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims ) - out_shape = tuple(da.sizes.get(dim, like_da.sizes[dim]) for dim in out_dims) + out_shape = tuple(var.sizes.get(dim, like_da.sizes[dim]) for dim in out_dims) idxr = 
indexer._variable._data # shuffle indices that can be reshaped blockwise to desired shape @@ -221,17 +222,17 @@ def _vindex_like( flat_indices = [ idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(core_dim_chunks) ] - shuffled = dask.array.shuffle( - array, flat_indices, axis=da.get_axis_num(dim), chunks="auto" - ) + shuffled = var._shuffle(flat_indices, dim=dim, chunks="auto") # shuffle with `chunks="auto"` could change chunks, so we recalculate out_chunks - new_chunksizes = dict(zip(da.dims, shuffled.chunks, strict=True)) + new_chunksizes = dict(zip(var.dims, shuffled.chunks, strict=True)) out_chunks = tuple( new_chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims ) if shuffled.shape != out_shape: - shuffled = reshape_blockwise(shuffled, shape=out_shape, chunks=out_chunks) - return Variable(dims=out_dims, data=shuffled, attrs=da.attrs) + out_data = reshape_blockwise(shuffled._data, shape=out_shape, chunks=out_chunks) + else: + out_data = shuffled._data + return Variable(out_dims, out_data, var.attrs) class _DummyGroup(Generic[T_Xarray]):