From f6c289212731ee2f0e3e8c72922d443941cb58ec Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 11 Dec 2025 12:09:38 -0500 Subject: [PATCH 1/8] fix group by > 1 group, then map --- xarray/core/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 827c0a3588f..0fceb9c5b73 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -177,7 +177,9 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray return None positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions] - newpositions = nputils.inverse_permutation(np.concatenate(positions), N) + newpositions = nputils.inverse_permutation( + np.concatenate(tuple(p for p in positions if p)), N + ) return newpositions[newpositions != -1] From d4e203dd2369ac8ce44e789888ca380a273b3762 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 11 Dec 2025 12:27:51 -0500 Subject: [PATCH 2/8] Update xarray/core/groupby.py Co-authored-by: Deepak Cherian --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 0fceb9c5b73..b2dca323c25 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -178,7 +178,7 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions] newpositions = nputils.inverse_permutation( - np.concatenate(tuple(p for p in positions if p)), N + np.concatenate(tuple(p for p in positions if len(p) > 0)), N ) return newpositions[newpositions != -1] From 9d7c9ea1f249a498ce81112ba9559705ece50341 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Thu, 11 Dec 2025 12:37:01 -0500 Subject: [PATCH 3/8] add test --- xarray/tests/test_groupby.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 47ea2fcd2b0..d24284d38b3 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3849,6 +3849,19 @@ def test_groupby_bins_mean_time_series(): assert "measurement" in ds_agged.data_vars assert ds_agged.time.dtype == np.dtype("datetime64[ns]") +def test_groupby_multi_map(): + # https://github.com/pydata/xarray/issues/11004 + d = xr.DataArray( + [[0, 1], [2, 3]], + coords={ + "lon": (["ny", "nx"], [[30, 40], [40, 50]]), + "lat": (["ny", "nx"], [[10, 10], [20, 20]]), + }, + dims=["ny", "nx"], + ) + + d.groupby(('lon', 'lat')).map(lambda x: x) + # TODO: Possible property tests to add to this module # 1. 
lambda x: x From ac36a5aebc96d8465698702b9aa20d480eeb2e2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 17:37:36 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/tests/test_groupby.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index d24284d38b3..7b8b9db29f8 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3849,6 +3849,7 @@ def test_groupby_bins_mean_time_series(): assert "measurement" in ds_agged.data_vars assert ds_agged.time.dtype == np.dtype("datetime64[ns]") + def test_groupby_multi_map(): # https://github.com/pydata/xarray/issues/11004 d = xr.DataArray( @@ -3860,7 +3861,7 @@ def test_groupby_multi_map(): dims=["ny", "nx"], ) - d.groupby(('lon', 'lat')).map(lambda x: x) + d.groupby(("lon", "lat")).map(lambda x: x) # TODO: Possible property tests to add to this module From 1d1230ab345485c3a6cf8a346c166a7a6cb22556 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 11 Dec 2025 20:51:51 -0700 Subject: [PATCH 5/8] Fix --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b2dca323c25..0e74ce65a01 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -173,7 +173,7 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray if isinstance(positions[0], slice): positions = _consolidate_slices(positions) - if positions == slice(None): + if positions == [slice(None)] or positions == [slice(0, None)]: return None positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions] From 4b4e209f9ada29bdaa30d16ca8b51c9f4c938473 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Sat, 13 Dec 2025 14:51:23 -0500 Subject: [PATCH 6/8] add check for None --- xarray/core/groupby.py | 222 ++++++++++++++++++----------------- xarray/tests/test_groupby.py | 4 +- 2 files changed, 114 insertions(+), 112 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 0e74ce65a01..989f9e63a9f 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -38,7 +38,6 @@ from xarray.core.types import ( Dims, QuantileMethods, - T_DataArray, T_DataWithCoords, T_Xarray, ) @@ -143,10 +142,10 @@ def _consolidate_slices(slices: list[slice]) -> list[slice]: if not isinstance(slice_, slice): raise ValueError(f"list element is not a slice: {slice_!r}") if ( - result - and last_slice.stop == slice_.start - and _is_one_or_none(last_slice.step) - and _is_one_or_none(slice_.step) + result + and last_slice.stop == slice_.start + and _is_one_or_none(last_slice.step) + and _is_one_or_none(slice_.step) ): last_slice = slice(last_slice.start, slice_.stop, slice_.step) result[-1] = last_slice @@ -176,7 +175,6 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray if positions == [slice(None)] or positions == [slice(0, None)]: return None positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions] - newpositions = nputils.inverse_permutation( np.concatenate(tuple(p for p in positions if len(p) > 0)), N ) @@ -213,7 +211,8 @@ def data(self) -> np.ndarray: return np.arange(self.size, dtype=int) def __array__( - self, dtype: np.typing.DTypeLike | None = None, /, *, copy: bool | None = None + self, dtype: np.typing.DTypeLike | None = None, /, *, + copy: bool | 
None = None ) -> np.ndarray: if copy is False: raise NotImplementedError(f"An array copy is necessary, got {copy = }.") @@ -255,7 +254,7 @@ def to_array(self) -> DataArray: def _ensure_1d( - group: T_Group, obj: T_DataWithCoords + group: T_Group, obj: T_DataWithCoords ) -> tuple[ T_Group, T_DataWithCoords, @@ -349,7 +348,7 @@ def __post_init__(self) -> None: ) if not isinstance(self.group, _DummyGroup) and is_chunked_array( - self.group.variable._data + self.group.variable._data ): # This requires a pass to discover the groups present if isinstance(self.grouper, UniqueGrouper) and self.grouper.labels is None: @@ -358,7 +357,7 @@ def __post_init__(self) -> None: ) # this requires a pass to compute the bin edges if isinstance(self.grouper, BinGrouper) and isinstance( - self.grouper.bins, int + self.grouper.bins, int ): raise ValueError( "Please pass explicit bin edges to BinGrouper using the ``bins`` kwarg" @@ -385,11 +384,11 @@ def __len__(self) -> int: def _parse_group_and_groupers( - obj: T_Xarray, - group: GroupInput, - groupers: dict[str, Grouper], - *, - eagerly_compute_group: Literal[False] | None, + obj: T_Xarray, + group: GroupInput, + groupers: dict[str, Grouper], + *, + eagerly_compute_group: Literal[False] | None, ) -> tuple[ResolvedGrouper, ...]: from xarray.core.dataarray import DataArray from xarray.groupers import Grouper, UniqueGrouper @@ -457,7 +456,7 @@ def _validate_groupby_squeeze(squeeze: Literal[False]) -> None: def _resolve_group( - obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable + obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable ) -> T_Group: from xarray.core.dataarray import DataArray @@ -541,7 +540,8 @@ def factorize(self) -> EncodedGroups: # NaNs; as well as values outside the bins are coded by -1 # Restore these after the raveling broadcasted_masks = broadcast(*masks) - mask = functools.reduce(np.logical_or, broadcasted_masks) # type: ignore[arg-type] + mask = functools.reduce(np.logical_or, + broadcasted_masks) # type: ignore[arg-type] _flatcodes = where(mask.data, -1, _flatcodes) full_index = pd.MultiIndex.from_product( @@ -637,10 +637,10 @@ class GroupBy(Generic[T_Xarray]): encoded: EncodedGroups def __init__( - self, - obj: T_Xarray, - groupers: tuple[ResolvedGrouper, ...], - restore_coord_dims: bool = True, + self, + obj: T_Xarray, + groupers: tuple[ResolvedGrouper, ...], + restore_coord_dims: bool = True, ) -> None: """Create a GroupBy object @@ -663,8 +663,8 @@ def __init__( self.encoded = grouper.encoded else: if any( - isinstance(obj._indexes.get(grouper.name, None), PandasMultiIndex) - for grouper in groupers + isinstance(obj._indexes.get(grouper.name, None), PandasMultiIndex) + for grouper in groupers ): raise NotImplementedError( "Grouping by multiple variables, one of which " @@ -783,24 +783,24 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray: return unstacked # type: ignore[return-value] def map( - self, - func: Callable, - args: tuple[Any, ...] = (), - shortcut: bool | None = None, - **kwargs: Any, + self, + func: Callable, + args: tuple[Any, ...] 
= (), + shortcut: bool | None = None, + **kwargs: Any, ) -> T_Xarray: raise NotImplementedError() def reduce( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, ) -> T_Xarray: raise NotImplementedError() @@ -984,15 +984,15 @@ def _maybe_reindex(self, combined): the correct order here. """ has_missing_groups = ( - self.encoded.unique_coord.size != self.encoded.full_index.size + self.encoded.unique_coord.size != self.encoded.full_index.size ) indexers = {} for grouper in self.groupers: index = combined._indexes.get(grouper.name, None) if (has_missing_groups and index is not None) or ( - len(self.groupers) > 1 - and not isinstance(grouper.full_index, pd.RangeIndex) - and not index.index.equals(grouper.full_index) + len(self.groupers) > 1 + and not isinstance(grouper.full_index, pd.RangeIndex) + and not (index is not None and index.index.equals(grouper.full_index)) ): indexers[grouper.name] = grouper.full_index if indexers: @@ -1024,17 +1024,17 @@ def _maybe_unstack(self, obj): grouper.name for grouper in self.groupers if isinstance(grouper.group, _DummyGroup) - and isinstance(grouper.grouper, UniqueGrouper) + and isinstance(grouper.grouper, UniqueGrouper) ] obj = obj.drop_vars(to_drop) return obj def _flox_reduce( - self, - dim: Dims, - keep_attrs: bool | None = None, - **kwargs: Any, + self, + dim: Dims, + keep_attrs: bool | None = None, + **kwargs: Any, ) -> T_Xarray: """Adaptor function that translates our groupby API to that of flox.""" import flox @@ -1045,7 +1045,8 @@ def _flox_reduce( obj = self._original_obj variables = ( {k: v.variable for k, v in obj.data_vars.items()} - if isinstance(obj, Dataset) # type: ignore[redundant-expr] # seems to be a mypy bug + if isinstance(obj, + Dataset) # type: ignore[redundant-expr] # seems to be a mypy bug else obj._coords ) @@ -1073,11 +1074,11 @@ def _flox_reduce( name: var for name, var in variables.items() if ( - not _is_numeric_aggregatable_dtype(var) - # this avoids dropping any levels of a MultiIndex, which raises - # a warning - and name not in midx_grouping_vars - and name not in obj.dims + not _is_numeric_aggregatable_dtype(var) + # this avoids dropping any levels of a MultiIndex, which raises + # a warning + and name not in midx_grouping_vars + and name not in obj.dims ) } else: @@ -1097,7 +1098,7 @@ def _flox_reduce( parsed_dim_list = list() # preserve order for dim_ in itertools.chain( - *(grouper.codes.dims for grouper in self.groupers) + *(grouper.codes.dims for grouper in self.groupers) ): if dim_ not in parsed_dim_list: parsed_dim_list.append(dim_) @@ -1111,12 +1112,13 @@ def _flox_reduce( # Better to control it here than in flox. 
for grouper in self.groupers: if any( - d not in grouper.codes.dims and d not in obj.dims for d in parsed_dim + d not in grouper.codes.dims and d not in obj.dims for d in + parsed_dim ): raise ValueError(f"cannot reduce over dimensions {dim}.") has_missing_groups = ( - self.encoded.unique_coord.size != self.encoded.full_index.size + self.encoded.unique_coord.size != self.encoded.full_index.size ) if self._by_chunked or has_missing_groups or kwargs.get("min_count", 0) > 0: # Xarray *always* returns np.nan when there are no observations in a group, @@ -1184,9 +1186,9 @@ def _flox_reduce( for name, var in variables.items(): dims_set = set(var.dims) if ( - dims_set <= set(parsed_dim) - and (dims_set & set(result.dims)) - and name not in result_variables + dims_set <= set(parsed_dim) + and (dims_set & set(result.dims)) + and name not in result_variables ): to_broadcast[name] = var for name, var in to_broadcast.items(): @@ -1231,14 +1233,14 @@ def fillna(self, value: Any) -> T_Xarray: return ops.fillna(self, value) def quantile( - self, - q: ArrayLike, - dim: Dims = None, - *, - method: QuantileMethods = "linear", - keep_attrs: bool | None = None, - skipna: bool | None = None, - interpolation: QuantileMethods | None = None, + self, + q: ArrayLike, + dim: Dims = None, + *, + method: QuantileMethods = "linear", + keep_attrs: bool | None = None, + skipna: bool | None = None, + interpolation: QuantileMethods | None = None, ) -> T_Xarray: """Compute the qth quantile over each array in the groups and concatenate them together into a new array. @@ -1360,10 +1362,10 @@ def quantile( q = np.asarray(q, dtype=np.float64) if ( - method == "linear" - and OPTIONS["use_flox"] - and contains_only_chunked_or_numpy(self._obj) - and module_available("flox", minversion="0.9.4") + method == "linear" + and OPTIONS["use_flox"] + and contains_only_chunked_or_numpy(self._obj) + and module_available("flox", minversion="0.9.4") ): result = self._flox_reduce( func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna @@ -1405,15 +1407,15 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray: return ops.where_method(self, cond, other) def _first_or_last( - self, - op: Literal["first" | "last"], - skipna: bool | None, - keep_attrs: bool | None, + self, + op: Literal["first" | "last"], + skipna: bool | None, + keep_attrs: bool | None, ): if all( - isinstance(maybe_slice, slice) - and (maybe_slice.stop == maybe_slice.start + 1) - for maybe_slice in self.encoded.group_indices + isinstance(maybe_slice, slice) + and (maybe_slice.stop == maybe_slice.start + 1) + for maybe_slice in self.encoded.group_indices ): # NB. 
this is currently only used for reductions along an existing # dimension @@ -1421,9 +1423,9 @@ def _first_or_last( if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) if ( - module_available("flox", minversion="0.10.0") - and OPTIONS["use_flox"] - and contains_only_chunked_or_numpy(self._obj) + module_available("flox", minversion="0.10.0") + and OPTIONS["use_flox"] + and contains_only_chunked_or_numpy(self._obj) ): import flox.xrdtypes @@ -1444,7 +1446,7 @@ def _first_or_last( return result def first( - self, skipna: bool | None = None, keep_attrs: bool | None = None + self, skipna: bool | None = None, keep_attrs: bool | None = None ) -> T_Xarray: """ Return the first element of each group along the group dimension @@ -1465,7 +1467,7 @@ def first( return self._first_or_last("first", skipna, keep_attrs) def last( - self, skipna: bool | None = None, keep_attrs: bool | None = None + self, skipna: bool | None = None, keep_attrs: bool | None = None ) -> T_Xarray: """ Return the last element of each group along the group dimension @@ -1560,11 +1562,11 @@ def lookup_order(dimension): return stacked def map( - self, - func: Callable[..., DataArray], - args: tuple[Any, ...] = (), - shortcut: bool | None = None, - **kwargs: Any, + self, + func: Callable[..., DataArray], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, ) -> DataArray: """Apply a function to each array in the group and concatenate them together into a new array. @@ -1654,15 +1656,15 @@ def _combine(self, applied, shortcut=False): return combined def reduce( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, ) -> DataArray: """Reduce the items in this group by applying `func` along some dimension(s). @@ -1742,11 +1744,11 @@ def dims(self) -> Frozen[Hashable, int]: return FrozenMappingWarningOnValuesAccess(self._dims) def map( - self, - func: Callable[..., Dataset], - args: tuple[Any, ...] = (), - shortcut: bool | None = None, - **kwargs: Any, + self, + func: Callable[..., Dataset], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, ) -> Dataset: """Apply a function to each Dataset in the group and concatenate them together into a new Dataset. @@ -1818,15 +1820,15 @@ def _combine(self, applied): return combined def reduce( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, ) -> Dataset: """Reduce the items in this group by applying `func` along some dimension(s). 
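Most of this patch is editor re-indentation that pre-commit reverts in the next commit; the functional changes are the None guard in `_maybe_reindex` (with multiple groupers, `combined._indexes.get(grouper.name, None)` can legitimately return None, and the old condition dereferenced `index.index` on it unconditionally) plus the stronger round-trip test that follows. A minimal sketch of the guarded comparison; `HypotheticalIndexWrapper` is a made-up stand-in for xarray's internal index object, not the real class:

    import pandas as pd

    class HypotheticalIndexWrapper:
        # made-up stand-in: exposes the underlying pandas index as `.index`,
        # mirroring how the patched condition accesses it
        def __init__(self, idx: pd.Index) -> None:
            self.index = idx

    full_index = pd.Index([10, 20], name="lat")

    for index in (None, HypotheticalIndexWrapper(pd.Index([10], name="lat"))):
        # the patched condition short-circuits on `index is not None`,
        # so `.index` is never dereferenced on None
        needs_reindex = not (index is not None and index.index.equals(full_index))
        print(needs_reindex)  # True both times: index missing, or labels incomplete

The un-guarded form `not index.index.equals(grouper.full_index)` raises AttributeError as soon as a grouper's name has no index on the combined result.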
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 7b8b9db29f8..31869ea0e2f 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3860,8 +3860,8 @@ def test_groupby_multi_map(): }, dims=["ny", "nx"], ) - - d.groupby(("lon", "lat")).map(lambda x: x) + xr.testing.assert_equal(d,d.groupby("lon").map(lambda x: x)) + xr.testing.assert_equal(d, d.groupby(("lon", "lat")).map(lambda x: x)) # TODO: Possible property tests to add to this module From 102787ffb0491c26160f50c861e539d6af1a8db4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Dec 2025 19:51:51 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 220 +++++++++++++++++------------------ xarray/tests/test_groupby.py | 2 +- 2 files changed, 109 insertions(+), 113 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 989f9e63a9f..eb0690d3532 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -142,10 +142,10 @@ def _consolidate_slices(slices: list[slice]) -> list[slice]: if not isinstance(slice_, slice): raise ValueError(f"list element is not a slice: {slice_!r}") if ( - result - and last_slice.stop == slice_.start - and _is_one_or_none(last_slice.step) - and _is_one_or_none(slice_.step) + result + and last_slice.stop == slice_.start + and _is_one_or_none(last_slice.step) + and _is_one_or_none(slice_.step) ): last_slice = slice(last_slice.start, slice_.stop, slice_.step) result[-1] = last_slice @@ -211,8 +211,7 @@ def data(self) -> np.ndarray: return np.arange(self.size, dtype=int) def __array__( - self, dtype: np.typing.DTypeLike | None = None, /, *, - copy: bool | None = None + self, dtype: np.typing.DTypeLike | None = None, /, *, copy: bool | None = None ) -> np.ndarray: if copy is False: raise NotImplementedError(f"An array copy is necessary, got {copy = }.") @@ -254,7 +253,7 @@ def to_array(self) -> DataArray: def _ensure_1d( - group: T_Group, obj: T_DataWithCoords + group: T_Group, obj: T_DataWithCoords ) -> tuple[ T_Group, T_DataWithCoords, @@ -348,7 +347,7 @@ def __post_init__(self) -> None: ) if not isinstance(self.group, _DummyGroup) and is_chunked_array( - self.group.variable._data + self.group.variable._data ): # This requires a pass to discover the groups present if isinstance(self.grouper, UniqueGrouper) and self.grouper.labels is None: @@ -357,7 +356,7 @@ def __post_init__(self) -> None: ) # this requires a pass to compute the bin edges if isinstance(self.grouper, BinGrouper) and isinstance( - self.grouper.bins, int + self.grouper.bins, int ): raise ValueError( "Please pass explicit bin edges to BinGrouper using the ``bins`` kwarg" @@ -384,11 +383,11 @@ def __len__(self) -> int: def _parse_group_and_groupers( - obj: T_Xarray, - group: GroupInput, - groupers: dict[str, Grouper], - *, - eagerly_compute_group: Literal[False] | None, + obj: T_Xarray, + group: GroupInput, + groupers: dict[str, Grouper], + *, + eagerly_compute_group: Literal[False] | None, ) -> tuple[ResolvedGrouper, ...]: from xarray.core.dataarray import DataArray from xarray.groupers import Grouper, UniqueGrouper @@ -456,7 +455,7 @@ def _validate_groupby_squeeze(squeeze: Literal[False]) -> None: def _resolve_group( - obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable + obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable ) -> T_Group: from xarray.core.dataarray import 
DataArray @@ -540,8 +539,7 @@ def factorize(self) -> EncodedGroups: # NaNs; as well as values outside the bins are coded by -1 # Restore these after the raveling broadcasted_masks = broadcast(*masks) - mask = functools.reduce(np.logical_or, - broadcasted_masks) # type: ignore[arg-type] + mask = functools.reduce(np.logical_or, broadcasted_masks) # type: ignore[arg-type] _flatcodes = where(mask.data, -1, _flatcodes) full_index = pd.MultiIndex.from_product( @@ -637,10 +635,10 @@ class GroupBy(Generic[T_Xarray]): encoded: EncodedGroups def __init__( - self, - obj: T_Xarray, - groupers: tuple[ResolvedGrouper, ...], - restore_coord_dims: bool = True, + self, + obj: T_Xarray, + groupers: tuple[ResolvedGrouper, ...], + restore_coord_dims: bool = True, ) -> None: """Create a GroupBy object @@ -663,8 +661,8 @@ def __init__( self.encoded = grouper.encoded else: if any( - isinstance(obj._indexes.get(grouper.name, None), PandasMultiIndex) - for grouper in groupers + isinstance(obj._indexes.get(grouper.name, None), PandasMultiIndex) + for grouper in groupers ): raise NotImplementedError( "Grouping by multiple variables, one of which " @@ -783,24 +781,24 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray: return unstacked # type: ignore[return-value] def map( - self, - func: Callable, - args: tuple[Any, ...] = (), - shortcut: bool | None = None, - **kwargs: Any, + self, + func: Callable, + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, ) -> T_Xarray: raise NotImplementedError() def reduce( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, ) -> T_Xarray: raise NotImplementedError() @@ -984,15 +982,15 @@ def _maybe_reindex(self, combined): the correct order here. 
""" has_missing_groups = ( - self.encoded.unique_coord.size != self.encoded.full_index.size + self.encoded.unique_coord.size != self.encoded.full_index.size ) indexers = {} for grouper in self.groupers: index = combined._indexes.get(grouper.name, None) if (has_missing_groups and index is not None) or ( - len(self.groupers) > 1 - and not isinstance(grouper.full_index, pd.RangeIndex) - and not (index is not None and index.index.equals(grouper.full_index)) + len(self.groupers) > 1 + and not isinstance(grouper.full_index, pd.RangeIndex) + and not (index is not None and index.index.equals(grouper.full_index)) ): indexers[grouper.name] = grouper.full_index if indexers: @@ -1024,17 +1022,17 @@ def _maybe_unstack(self, obj): grouper.name for grouper in self.groupers if isinstance(grouper.group, _DummyGroup) - and isinstance(grouper.grouper, UniqueGrouper) + and isinstance(grouper.grouper, UniqueGrouper) ] obj = obj.drop_vars(to_drop) return obj def _flox_reduce( - self, - dim: Dims, - keep_attrs: bool | None = None, - **kwargs: Any, + self, + dim: Dims, + keep_attrs: bool | None = None, + **kwargs: Any, ) -> T_Xarray: """Adaptor function that translates our groupby API to that of flox.""" import flox @@ -1045,8 +1043,7 @@ def _flox_reduce( obj = self._original_obj variables = ( {k: v.variable for k, v in obj.data_vars.items()} - if isinstance(obj, - Dataset) # type: ignore[redundant-expr] # seems to be a mypy bug + if isinstance(obj, Dataset) # type: ignore[redundant-expr] # seems to be a mypy bug else obj._coords ) @@ -1074,11 +1071,11 @@ def _flox_reduce( name: var for name, var in variables.items() if ( - not _is_numeric_aggregatable_dtype(var) - # this avoids dropping any levels of a MultiIndex, which raises - # a warning - and name not in midx_grouping_vars - and name not in obj.dims + not _is_numeric_aggregatable_dtype(var) + # this avoids dropping any levels of a MultiIndex, which raises + # a warning + and name not in midx_grouping_vars + and name not in obj.dims ) } else: @@ -1098,7 +1095,7 @@ def _flox_reduce( parsed_dim_list = list() # preserve order for dim_ in itertools.chain( - *(grouper.codes.dims for grouper in self.groupers) + *(grouper.codes.dims for grouper in self.groupers) ): if dim_ not in parsed_dim_list: parsed_dim_list.append(dim_) @@ -1112,13 +1109,12 @@ def _flox_reduce( # Better to control it here than in flox. 
for grouper in self.groupers: if any( - d not in grouper.codes.dims and d not in obj.dims for d in - parsed_dim + d not in grouper.codes.dims and d not in obj.dims for d in parsed_dim ): raise ValueError(f"cannot reduce over dimensions {dim}.") has_missing_groups = ( - self.encoded.unique_coord.size != self.encoded.full_index.size + self.encoded.unique_coord.size != self.encoded.full_index.size ) if self._by_chunked or has_missing_groups or kwargs.get("min_count", 0) > 0: # Xarray *always* returns np.nan when there are no observations in a group, @@ -1186,9 +1182,9 @@ def _flox_reduce( for name, var in variables.items(): dims_set = set(var.dims) if ( - dims_set <= set(parsed_dim) - and (dims_set & set(result.dims)) - and name not in result_variables + dims_set <= set(parsed_dim) + and (dims_set & set(result.dims)) + and name not in result_variables ): to_broadcast[name] = var for name, var in to_broadcast.items(): @@ -1233,14 +1229,14 @@ def fillna(self, value: Any) -> T_Xarray: return ops.fillna(self, value) def quantile( - self, - q: ArrayLike, - dim: Dims = None, - *, - method: QuantileMethods = "linear", - keep_attrs: bool | None = None, - skipna: bool | None = None, - interpolation: QuantileMethods | None = None, + self, + q: ArrayLike, + dim: Dims = None, + *, + method: QuantileMethods = "linear", + keep_attrs: bool | None = None, + skipna: bool | None = None, + interpolation: QuantileMethods | None = None, ) -> T_Xarray: """Compute the qth quantile over each array in the groups and concatenate them together into a new array. @@ -1362,10 +1358,10 @@ def quantile( q = np.asarray(q, dtype=np.float64) if ( - method == "linear" - and OPTIONS["use_flox"] - and contains_only_chunked_or_numpy(self._obj) - and module_available("flox", minversion="0.9.4") + method == "linear" + and OPTIONS["use_flox"] + and contains_only_chunked_or_numpy(self._obj) + and module_available("flox", minversion="0.9.4") ): result = self._flox_reduce( func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna @@ -1407,15 +1403,15 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray: return ops.where_method(self, cond, other) def _first_or_last( - self, - op: Literal["first" | "last"], - skipna: bool | None, - keep_attrs: bool | None, + self, + op: Literal["first" | "last"], + skipna: bool | None, + keep_attrs: bool | None, ): if all( - isinstance(maybe_slice, slice) - and (maybe_slice.stop == maybe_slice.start + 1) - for maybe_slice in self.encoded.group_indices + isinstance(maybe_slice, slice) + and (maybe_slice.stop == maybe_slice.start + 1) + for maybe_slice in self.encoded.group_indices ): # NB. 
this is currently only used for reductions along an existing # dimension @@ -1423,9 +1419,9 @@ def _first_or_last( if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) if ( - module_available("flox", minversion="0.10.0") - and OPTIONS["use_flox"] - and contains_only_chunked_or_numpy(self._obj) + module_available("flox", minversion="0.10.0") + and OPTIONS["use_flox"] + and contains_only_chunked_or_numpy(self._obj) ): import flox.xrdtypes @@ -1446,7 +1442,7 @@ def _first_or_last( return result def first( - self, skipna: bool | None = None, keep_attrs: bool | None = None + self, skipna: bool | None = None, keep_attrs: bool | None = None ) -> T_Xarray: """ Return the first element of each group along the group dimension @@ -1467,7 +1463,7 @@ def first( return self._first_or_last("first", skipna, keep_attrs) def last( - self, skipna: bool | None = None, keep_attrs: bool | None = None + self, skipna: bool | None = None, keep_attrs: bool | None = None ) -> T_Xarray: """ Return the last element of each group along the group dimension @@ -1562,11 +1558,11 @@ def lookup_order(dimension): return stacked def map( - self, - func: Callable[..., DataArray], - args: tuple[Any, ...] = (), - shortcut: bool | None = None, - **kwargs: Any, + self, + func: Callable[..., DataArray], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, ) -> DataArray: """Apply a function to each array in the group and concatenate them together into a new array. @@ -1656,15 +1652,15 @@ def _combine(self, applied, shortcut=False): return combined def reduce( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, ) -> DataArray: """Reduce the items in this group by applying `func` along some dimension(s). @@ -1744,11 +1740,11 @@ def dims(self) -> Frozen[Hashable, int]: return FrozenMappingWarningOnValuesAccess(self._dims) def map( - self, - func: Callable[..., Dataset], - args: tuple[Any, ...] = (), - shortcut: bool | None = None, - **kwargs: Any, + self, + func: Callable[..., Dataset], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, ) -> Dataset: """Apply a function to each Dataset in the group and concatenate them together into a new Dataset. @@ -1820,15 +1816,15 @@ def _combine(self, applied): return combined def reduce( - self, - func: Callable[..., Any], - dim: Dims = None, - *, - axis: int | Sequence[int] | None = None, - keep_attrs: bool | None = None, - keepdims: bool = False, - shortcut: bool = True, - **kwargs: Any, + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, ) -> Dataset: """Reduce the items in this group by applying `func` along some dimension(s). 
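Patch 7 is pre-commit undoing the re-indentation that patch 6 introduced; the test tweak that follows only normalizes comma spacing, so the functional content of the series is unchanged. Worth recording why patch 2 exists at all: patch 1's filter `tuple(p for p in positions if p)` relied on array truthiness, which NumPy rejects for multi-element arrays and, worse, evaluates to False for a single-element array holding 0. A self-contained sketch of both failure modes and the `len(p) > 0` form that replaced it:

    import numpy as np

    # middle group is empty, as happens when a grouper has no observations
    positions = [np.arange(0, 2), np.arange(2, 2), np.arange(2, 4)]

    # Patch 1's filter: `if p` asks NumPy for the truth value of an array,
    # which raises for arrays with more than one element.
    try:
        tuple(p for p in positions if p)
    except ValueError as err:
        print(err)  # "The truth value of an array with more than one element is ambiguous..."

    # It is also wrong for a single-element array holding 0:
    # bool(np.array([0])) is False, so the group at position 0
    # would be dropped silently.
    print(bool(np.arange(0, 1)))  # False

    # Patch 2's filter keeps every non-empty group unambiguously:
    kept = tuple(p for p in positions if len(p) > 0)
    print(np.concatenate(kept))  # [0 1 2 3]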
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 31869ea0e2f..5075fa12b92 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3860,7 +3860,7 @@ def test_groupby_multi_map(): }, dims=["ny", "nx"], ) - xr.testing.assert_equal(d,d.groupby("lon").map(lambda x: x)) + xr.testing.assert_equal(d, d.groupby("lon").map(lambda x: x)) xr.testing.assert_equal(d, d.groupby(("lon", "lat")).map(lambda x: x)) From 99c102dbaf6c72fe6543cd640e1b4c082cdf3088 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 15 Dec 2025 19:48:28 -0700 Subject: [PATCH 8/8] Add T_DataArray type import in groupby.py --- xarray/core/groupby.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index eb0690d3532..77ea1b341c9 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -38,6 +38,7 @@ from xarray.core.types import ( Dims, QuantileMethods, + T_DataArray, T_DataWithCoords, T_Xarray, )
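In sum: patches 1-2 teach `_inverse_permutation_indices` to drop empty groups before calling `nputils.inverse_permutation`, using `len(p) > 0` rather than array truthiness; patch 5 repairs the early return (`positions` is a list at that point, so the old `positions == slice(None)` comparison was always False, and it now matches `[slice(None)]` or `[slice(0, None)]`); patch 6 adds the None guard in `_maybe_reindex`; and patch 8 restores the `T_DataArray` import that patch 6 dropped. The reproducer from `test_groupby_multi_map` exercises the whole path; with the series applied both assertions pass, where grouping by two coordinate variables and mapping previously failed (issue #11004):

    import xarray as xr

    # 2x2 array with two 2-D coordinate variables sharing dims (ny, nx)
    d = xr.DataArray(
        [[0, 1], [2, 3]],
        coords={
            "lon": (["ny", "nx"], [[30, 40], [40, 50]]),
            "lat": (["ny", "nx"], [[10, 10], [20, 20]]),
        },
        dims=["ny", "nx"],
    )

    # the identity map must round-trip for one grouper and for two groupers
    xr.testing.assert_equal(d, d.groupby("lon").map(lambda x: x))
    xr.testing.assert_equal(d, d.groupby(("lon", "lat")).map(lambda x: x))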