diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 445ffd96e8..663e3829e2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5685,6 +5685,12 @@ async def remove_worker( f"Removing worker {ws.address!r} caused the cluster to lose scattered " f"data, which can't be recovered: {lost_keys} ({stimulus_id=})" ) + if not expected and processing_keys: + logger.warning( + f"Worker {ws.address!r} dropped unexpectedly. " + f"Interrupting {len(processing_keys)} processing tasks: " + f"{processing_keys} ({stimulus_id=})" + ) event_msg = { "action": "remove-worker", diff --git a/distributed/stealing.py b/distributed/stealing.py index cc6f1e6d65..5aee84a017 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -118,6 +118,7 @@ def __init__(self, scheduler: Scheduler): self.metrics = { "request_count_total": defaultdict(int), "request_cost_total": defaultdict(int), + "reject_count_margin_total": defaultdict(int), } self._request_counter = 0 self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm @@ -486,10 +487,22 @@ def balance(self) -> None: comm_cost_thief = self.scheduler.get_comm_cost(ts, thief) comm_cost_victim = self.scheduler.get_comm_cost(ts, victim) compute = self.scheduler._get_prefix_duration(ts.prefix) - if ( + + # Be conservative about marginal steals: require headroom equal + # to 50% of the thief's transfer cost to absorb estimation noise + # and routine network jitter. + margin = comm_cost_thief * 0.5 + + would_steal_without_margin = ( occ_thief + comm_cost_thief + compute <= occ_victim - (comm_cost_victim + compute) / 2 - ): + ) + would_steal_with_margin = ( + occ_thief + comm_cost_thief + compute + margin + <= occ_victim - (comm_cost_victim + compute) / 2 + ) + + if would_steal_with_margin: self.move_task_request(ts, victim, thief) cost = compute + comm_cost_victim log.append( @@ -520,6 +533,17 @@ def balance(self) -> None: # for removing ts from stealable. If we made sure to # properly clean up, we would not need this stealable.discard(ts) + elif would_steal_without_margin: + self.metrics["reject_count_margin_total"][level] += 1 + logger.debug( + "Work-stealing margin heuristic rejected steal of task %s " + "(thief=%s, victim=%s, level=%d, margin=%.4f)", + ts.key, + thief.address, + victim.address, + level, + margin, + ) self.scheduler.check_idle_saturated( victim, occ=combined_occupancy(victim) ) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 758a6e03c7..a09305fdd9 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -11,6 +11,7 @@ from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence from operator import mul from time import sleep +from unittest.mock import patch import pytest from tlz import merge, sliding_window @@ -1448,8 +1449,10 @@ def func(*args): "cost, ntasks, expect_steal", [ pytest.param(10, 10, False, id="not enough work to steal"), - pytest.param(10, 12, True, id="enough work to steal"), - pytest.param(20, 12, False, id="not enough work for increased cost"), + # The 50% margin heuristic raises the minimum backlog needed to justify + # stealing these expensive tasks; 12 was enough before, 17 is enough now. + pytest.param(10, 17, True, id="enough work to steal"), + pytest.param(20, 17, False, id="not enough work for increased cost"), ], ) def test_balance_expensive_tasks(cost, ntasks, expect_steal): @@ -2010,6 +2013,61 @@ def block(i: int, in_event: Event, block_event: Event) -> int: await block_event.set() +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 2, + config={ + "distributed.scheduler.work-stealing-interval": "100ms", + "distributed.scheduler.default-task-durations": {"slowidentity": 0.021}, + **NO_AMM, + }, +) +async def test_reject_count_margin_metric(c, s, a, b): + """ + Verify that the margin heuristic increments reject_count_margin_total + when a steal is suppressed that old logic would have permitted. + """ + steal = s.extensions["stealing"] + await steal.stop() + + # Use enough short tasks to satisfy Scheduler.check_idle_saturated() for a + # single busy worker while still keeping the steal in the margin-rejection + # window once get_comm_cost() is patched below. + futures = c.map( + slowidentity, + range(21), + workers=a.address, + allow_other_workers=True, + delay=0.021, + ) + + while len(s.tasks) < 21: + await asyncio.sleep(0.01) + + while len(a.state.tasks) < 21: + await asyncio.sleep(0.01) + + for ws in s.workers.values(): + s.check_idle_saturated(ws) + + a_ws = s.workers[a.address] + b_ws = s.workers[b.address] + assert a_ws in s.saturated, ( + f"Worker A not saturated: occupancy={a_ws.occupancy:.3f}, " + f"nthreads={a_ws.nthreads}, processing={len(a_ws.processing)}" + ) + assert ( + b_ws in s.idle.values() + ), f"Worker B not idle: processing={len(b_ws.processing)}" + + with patch.object( + s, "get_comm_cost", side_effect=lambda ts, ws: 0.3 if ws == b_ws else 0.0 + ): + steal.balance() + + assert sum(steal.metrics["reject_count_margin_total"].values()) >= 1 + + @gen_cluster( nthreads=[("", 1)], client=True, diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 5dbf62a430..39bf286d84 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2995,6 +2995,8 @@ async def test_log_remove_worker(c, s, a, b): "(stimulus_id='ungraceful')", f"Removing worker '{b.address}' caused the cluster to lose scattered " "data, which can't be recovered: {'z'} (stimulus_id='ungraceful')", + f"Worker {b.address!r} dropped unexpectedly. Interrupting 1 " + "processing tasks: {'y'} (stimulus_id='ungraceful')", "Lost all workers", ]