diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index d6b245983..3ff4db49a 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -405,7 +405,7 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]): __bool__ = ... def union( self, other: list[HashableT] | Self, sort: bool | None = None - ) -> Index: ... + ) -> Self: ... def intersection( self, other: list[S1] | Self, sort: bool | None = False ) -> Self: ... diff --git a/pandas-stubs/core/indexes/multi.pyi b/pandas-stubs/core/indexes/multi.pyi index e66c845ec..7dc76e7f7 100644 --- a/pandas-stubs/core/indexes/multi.pyi +++ b/pandas-stubs/core/indexes/multi.pyi @@ -135,7 +135,7 @@ class MultiIndex(Index): def append(self, other): ... def repeat(self, repeats, axis=...): ... def drop(self, codes, level: Level | None = None, errors: str = "raise") -> Self: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] - def swaplevel(self, i: int = -2, j: int = -1): ... + def swaplevel(self, i: int = -2, j: int = -1) -> Self: ... def reorder_levels(self, order): ... def sortlevel( self, diff --git a/pandas-stubs/core/indexes/range.pyi b/pandas-stubs/core/indexes/range.pyi index 4cb9ab35d..9b200d026 100644 --- a/pandas-stubs/core/indexes/range.pyi +++ b/pandas-stubs/core/indexes/range.pyi @@ -80,7 +80,7 @@ class RangeIndex(_IndexSubclassBase[int, np.int64]): def all(self, *args: Any, **kwargs: Any) -> bool: ... def any(self, *args: Any, **kwargs: Any) -> bool: ... @final - def union( + def union( # type: ignore[override] self, other: list[HashableT] | Index, sort: bool | None = None ) -> Index | Index[int] | RangeIndex: ... @overload # type: ignore[override] diff --git a/tests/indexes/test_indexes.py b/tests/indexes/test_indexes.py index 3b2960b3c..9182cd550 100644 --- a/tests/indexes/test_indexes.py +++ b/tests/indexes/test_indexes.py @@ -333,12 +333,18 @@ def test_range_index_union() -> None: def test_index_union_sort() -> None: """Test sort argument in pd.Index.union GH1264.""" check( - assert_type(pd.Index(["e", "f"]).union(["a", "b", "c"], sort=True), pd.Index), + assert_type( + pd.Index(["e", "f"]).union(["a", "b", "c"], sort=True), "pd.Index[str]" + ), pd.Index, + str, ) check( - assert_type(pd.Index(["e", "f"]).union(["a", "b", "c"], sort=False), pd.Index), + assert_type( + pd.Index(["e", "f"]).union(["a", "b", "c"], sort=False), "pd.Index[str]" + ), pd.Index, + str, ) @@ -1601,3 +1607,18 @@ def test_to_series() -> None: np.complexfloating, ) check(assert_type(Index(["1"]).to_series(), "pd.Series[str]"), pd.Series, str) + + +def test_multiindex_union() -> None: + """Test that MultiIndex.union returns MultiIndex""" + mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"]) + mi2 = pd.MultiIndex.from_product([["a", "b"], [3, 4]], names=["let", "num"]) + + check(assert_type(mi.union(mi2), "pd.MultiIndex"), pd.MultiIndex) + check(assert_type(mi.union([("c", 3), ("d", 4)]), "pd.MultiIndex"), pd.MultiIndex) + + +def test_multiindex_swaplevel() -> None: + """Test that MultiIndex.swaplevel returns MultiIndex""" + mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"]) + check(assert_type(mi.swaplevel(0, 1), "pd.MultiIndex"), pd.MultiIndex)