diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index ceae79031f8..a85a946a72d 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -49,7 +49,7 @@
     peek_at,
 )
 from xarray.core.variable import IndexVariable, Variable
-from xarray.namedarray.pycompat import is_chunked_array
+from xarray.namedarray.pycompat import is_chunked_array, is_duck_dask_array

 if TYPE_CHECKING:
     from numpy.typing import ArrayLike
@@ -172,6 +172,69 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray
     return newpositions[newpositions != -1]


+def _vindex_like(
+    da: DataArray, dim: Hashable, indexer: DataArray, like_ds: Dataset
+) -> Variable:
+    """
+    Apply a vectorized indexer, optionally matching the chunks of a DataArray
+    of the same name in `like_ds`. This is useful for GroupBy binary ops.
+    This function is intended to be used with Dataset.map.
+    """
+    # we want to use the fact that we know the chunksizes for the output (matches obj)
+    # so we can't just use Variable's indexing directly
+    array = da._variable._data
+    like_da = like_ds.get(da.name)
+    if not is_duck_dask_array(array):
+        if like_da is None or not is_duck_dask_array(like_da._variable._data):
+            # TODO: we should instead check for `shuffle` and `reshape_blockwise`
+            return da.isel({dim: indexer})
+        else:
+            da = da.chunk("auto")
+
+    var = da._variable
+    array = var._data
+
+    from dask.array.core import slices_from_chunks
+    from dask.graph_manipulation import clone
+
+    from xarray.core.dask_array_compat import reshape_blockwise
+
+    array = clone(array)  # FIXME: add to dask
+
+    # dimensions for indexed result
+    out_dims = tuple(
+        itertools.chain(
+            *(indexer.dims if this == dim else (this,) for this in var.dims)
+        )
+    )
+    out_chunks = tuple(
+        da.chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims
+    )
+    out_shape = tuple(var.sizes.get(dim, like_da.sizes[dim]) for dim in out_dims)
+    idxr = indexer._variable._data
+
+    # shuffle indices that can be reshaped blockwise to desired shape
+    core_dim_chunks = tuple(
+        chunks
+        for dim, chunks in zip(out_dims, out_chunks, strict=True)
+        if dim in indexer.dims
+    )
+    flat_indices = [
+        idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(core_dim_chunks)
+    ]
+    shuffled = var._shuffle(flat_indices, dim=dim, chunks="auto")
+    # shuffle with `chunks="auto"` could change chunks, so we recalculate out_chunks
+    new_chunksizes = dict(zip(var.dims, shuffled.chunks, strict=True))
+    out_chunks = tuple(
+        new_chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims
+    )
+    if shuffled.shape != out_shape:
+        out_data = reshape_blockwise(shuffled._data, shape=out_shape, chunks=out_chunks)
+    else:
+        out_data = shuffled._data
+    return Variable(out_dims, out_data, var.attrs)
+
+
 class _DummyGroup(Generic[T_Xarray]):
     """Class for keeping track of grouped dimensions without coordinates.

@@ -899,25 +962,19 @@ def _binary_op(self, other, f, reflexive=False):
             group = group.where(~mask, drop=True)
             codes = codes.where(~mask, drop=True).astype(int)

-        # if other is dask-backed, that's a hint that the
-        # "expanded" dataset is too big to hold in memory.
-        # this can be the case when `other` was read from disk
-        # and contains our lazy indexing classes
-        # We need to check for dask-backed Datasets
-        # so utils.is_duck_dask_array does not work for this check
-        if obj.chunks and not other.chunks:
-            # TODO: What about datasets with some dask vars, and others not?
-            # This handles dims other than `name``
-            chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims}
-            # a chunk size of 1 seems reasonable since we expect individual elements of
-            # other to be repeated multiple times across the reduced dimension(s)
-            chunks[name] = 1
-            other = other.chunk(chunks)
-
         # codes are defined for coord, so we align `other` with `coord`
         # before indexing
         other, _ = align(other, coord, join="right", copy=False)
-        expanded = other.isel({name: codes})
+
+        other_as_dataset = (
+            other._to_temp_dataset() if isinstance(other, DataArray) else other
+        )
+        obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj
+        expanded = other_as_dataset.map(
+            _vindex_like, dim=name, indexer=codes, like_ds=obj_as_dataset
+        )
+        if isinstance(other, DataArray):
+            expanded = other._from_temp_dataset(expanded)

         result = g(obj, expanded)

diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index e4383dd58a9..272f012564b 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -2654,12 +2654,14 @@ def test_groupby_math_auto_chunk() -> None:
         dims=("y", "x"),
         coords={"label": ("x", [2, 2, 1])},
     )
+    # `sub` is equivalent to da.groupby("label").min(...)
    sub = xr.DataArray(
         InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]}
     )
     chunked = da.chunk(x=1, y=2)
     chunked.label.load()
-    actual = chunked.groupby("label") - sub
+    with raise_if_dask_computes():
+        actual = chunked.groupby("label") - sub
     assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}
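For reviewers, a runnable sketch of the behaviour the updated test pins down. It mirrors `test_groupby_math_auto_chunk` but swaps `InaccessibleArray` for a real reduction; the 3x3 data values are an assumption (the actual values sit above the hunk shown here) chosen to be consistent with the asserted chunk sizes:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(9).reshape(3, 3),  # assumed 3x3 data; real values not in the hunk
    dims=("y", "x"),
    coords={"label": ("x", [2, 2, 1])},
)
# the per-label reduction that `sub` stands in for in the test
sub = da.groupby("label").min(...)

chunked = da.chunk(x=1, y=2)
chunked.label.load()

# The subtraction re-expands `sub` along "x". With this change the expansion
# goes through `_vindex_like`, so the result keeps the chunk structure of
# `chunked` instead of being rechunked to size 1 along the grouped dimension.
actual = chunked.groupby("label") - sub
assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}
```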
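And a toy illustration of the trick `_vindex_like` builds on: a vectorized take expressed as dask's shuffle, where each inner list of indices is guaranteed to land in a single output chunk, so the result can later be reshaped blockwise to the desired shape. This assumes a recent dask (>= 2024.08) that ships `dask.array.shuffle`; the indexer values are illustrative only:

```python
import dask.array
import numpy as np

x = dask.array.from_array(np.arange(6) * 10, chunks=2)  # chunks: (2, 2, 2)

# Repeated indices are allowed, like repeated GroupBy codes. Each inner list
# becomes one contiguous group in the output, in flattened order; these
# groups play the role of `flat_indices` in _vindex_like.
indexer = [[0, 0, 3], [5, 1]]
shuffled = dask.array.shuffle(x, indexer, axis=0)
print(shuffled.compute())  # [ 0  0 30 50 10]
```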