Skip to content

Commit

Permalink
Clean up tests after legacy DataFrame removal (#8972)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Dec 20, 2024
1 parent fd6149e commit 8f1b241
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 28 deletions.
14 changes: 3 additions & 11 deletions distributed/protocol/tests/test_highlevelgraph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import contextlib

import pytest

np = pytest.importorskip("numpy")
Expand Down Expand Up @@ -175,11 +173,9 @@ async def test_dataframe_annotations(c, s, a, b):
acol = df["a"]
bcol = df["b"]

ctx = contextlib.nullcontext()
if dd._dask_expr_enabled():
ctx = pytest.warns(
UserWarning, match="Annotations will be ignored when using query-planning"
)
ctx = pytest.warns(
UserWarning, match="Annotations will be ignored when using query-planning"
)

with dask.annotate(retries=retries), ctx:
df = acol + bcol
Expand All @@ -189,7 +185,3 @@ async def test_dataframe_annotations(c, s, a, b):

assert rdf.dtypes == np.float64
assert (rdf == 10.0).all()

if not dd._dask_expr_enabled():
# There is an annotation match per partition (i.e. task)
assert plugin.retry_matches == df.npartitions
9 changes: 3 additions & 6 deletions distributed/shuffle/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,10 @@ async def test_basic_merge(c, s, a, b, how):

joined = a.merge(b, left_on="y", right_on="y", how=how)

if dd._dask_expr_enabled():
# Ensure we're using a hash join
from dask_expr._merge import HashJoinP2P
# Ensure we're using a hash join
from dask_expr._merge import HashJoinP2P

assert any(
isinstance(expr, HashJoinP2P) for expr in joined.optimize()._expr.walk()
)
assert any(isinstance(expr, HashJoinP2P) for expr in joined.optimize()._expr.walk())

expected = pd.merge(A, B, how, "y")
await list_eq(joined, expected)
Expand Down
4 changes: 1 addition & 3 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,9 +1637,7 @@ async def test_multi(c, s, a, b):
await assert_scheduler_cleanup(s)


@pytest.mark.skipif(
dd._dask_expr_enabled(), reason="worker restrictions are not supported in dask-expr"
)
@pytest.mark.skipif(reason="worker restrictions are not supported in dask-expr")
@gen_cluster(client=True)
async def test_restrictions(c, s, a, b):
df = dask.datasets.timeseries(
Expand Down
9 changes: 1 addition & 8 deletions distributed/shuffle/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,7 @@
from distributed.core import PooledRPCCall
from distributed.shuffle._core import ShuffleId, ShuffleRun

UNPACK_PREFIX = "shuffle_p2p"
try:
import dask.dataframe as dd

if dd._dask_expr_enabled():
UNPACK_PREFIX = "p2pshuffle"
except ImportError:
pass
UNPACK_PREFIX = "p2pshuffle"


class PooledRPCShuffle(PooledRPCCall):
Expand Down

0 comments on commit 8f1b241

Please sign in to comment.