Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5685,6 +5685,12 @@ async def remove_worker(
f"Removing worker {ws.address!r} caused the cluster to lose scattered "
f"data, which can't be recovered: {lost_keys} ({stimulus_id=})"
)
if not expected and processing_keys:
logger.warning(
f"Worker {ws.address!r} dropped unexpectedly. "
f"Interrupting {len(processing_keys)} processing tasks: "
f"{processing_keys} ({stimulus_id=})"
)

event_msg = {
"action": "remove-worker",
Expand Down
28 changes: 26 additions & 2 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(self, scheduler: Scheduler):
self.metrics = {
"request_count_total": defaultdict(int),
"request_cost_total": defaultdict(int),
"reject_count_margin_total": defaultdict(int),
}
self._request_counter = 0
self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
Expand Down Expand Up @@ -486,10 +487,22 @@ def balance(self) -> None:
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
compute = self.scheduler._get_prefix_duration(ts.prefix)
if (

# Be conservative about marginal steals: require headroom equal
# to 50% of the thief's transfer cost to absorb estimation noise
# and routine network jitter.
margin = comm_cost_thief * 0.5

would_steal_without_margin = (
occ_thief + comm_cost_thief + compute
<= occ_victim - (comm_cost_victim + compute) / 2
):
)
would_steal_with_margin = (
occ_thief + comm_cost_thief + compute + margin
<= occ_victim - (comm_cost_victim + compute) / 2
)

if would_steal_with_margin:
self.move_task_request(ts, victim, thief)
cost = compute + comm_cost_victim
log.append(
Expand Down Expand Up @@ -520,6 +533,17 @@ def balance(self) -> None:
# for removing ts from stealable. If we made sure to
# properly clean up, we would not need this
stealable.discard(ts)
elif would_steal_without_margin:
self.metrics["reject_count_margin_total"][level] += 1
logger.debug(
"Work-stealing margin heuristic rejected steal of task %s "
"(thief=%s, victim=%s, level=%d, margin=%.4f)",
ts.key,
thief.address,
victim.address,
level,
margin,
)
self.scheduler.check_idle_saturated(
victim, occ=combined_occupancy(victim)
)
Expand Down
62 changes: 60 additions & 2 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence
from operator import mul
from time import sleep
from unittest.mock import patch

import pytest
from tlz import merge, sliding_window
Expand Down Expand Up @@ -1448,8 +1449,10 @@ def func(*args):
"cost, ntasks, expect_steal",
[
pytest.param(10, 10, False, id="not enough work to steal"),
pytest.param(10, 12, True, id="enough work to steal"),
pytest.param(20, 12, False, id="not enough work for increased cost"),
# The 50% margin heuristic raises the minimum backlog needed to justify
# stealing these expensive tasks; 12 was enough before, 17 is enough now.
pytest.param(10, 17, True, id="enough work to steal"),
pytest.param(20, 17, False, id="not enough work for increased cost"),
],
)
def test_balance_expensive_tasks(cost, ntasks, expect_steal):
Expand Down Expand Up @@ -2010,6 +2013,61 @@ def block(i: int, in_event: Event, block_event: Event) -> int:
await block_event.set()


@gen_cluster(
    client=True,
    nthreads=[("127.0.0.1", 1)] * 2,
    config={
        "distributed.scheduler.work-stealing-interval": "100ms",
        "distributed.scheduler.default-task-durations": {"slowidentity": 0.021},
        **NO_AMM,
    },
)
async def test_reject_count_margin_metric(c, s, a, b):
    """Check that a margin-suppressed steal bumps ``reject_count_margin_total``.

    All tasks are submitted with a preference for worker ``a`` (other workers
    allowed).  ``balance()`` is then invoked once by hand while
    ``Scheduler.get_comm_cost`` is patched to report a cost of 0.3 towards
    worker ``b`` and 0.0 towards anything else.  The test passes when at
    least one rejection has been recorded in the stealing extension's
    ``reject_count_margin_total`` metric.
    """
    n_tasks = 21

    stealing = s.extensions["stealing"]
    # Stop the extension first so that the only balance() call is the
    # explicit one made further down.
    await stealing.stop()

    # Enough short tasks are needed for check_idle_saturated() to flag the
    # single busy worker as saturated, while the steal must still land in the
    # window where only the margin heuristic turns it down once
    # get_comm_cost() is patched.
    # The returned futures are kept referenced for the duration of the test.
    futures = c.map(
        slowidentity,
        range(n_tasks),
        workers=a.address,
        allow_other_workers=True,
        delay=0.021,
    )

    # Wait for the scheduler, then worker ``a``, to know about every task.
    while len(s.tasks) < n_tasks:
        await asyncio.sleep(0.01)

    while len(a.state.tasks) < n_tasks:
        await asyncio.sleep(0.01)

    # Re-evaluate idle/saturated membership for every worker before checking
    # the preconditions of the scenario.
    for worker_state in s.workers.values():
        s.check_idle_saturated(worker_state)

    a_ws = s.workers[a.address]
    b_ws = s.workers[b.address]
    assert a_ws in s.saturated, (
        f"Worker A not saturated: occupancy={a_ws.occupancy:.3f}, "
        f"nthreads={a_ws.nthreads}, processing={len(a_ws.processing)}"
    )
    assert (
        b_ws in s.idle.values()
    ), f"Worker B not idle: processing={len(b_ws.processing)}"

    def fake_comm_cost(ts, ws):
        # Moving data to worker ``b`` is made to look expensive; everything
        # else is free.
        if ws == b_ws:
            return 0.3
        return 0.0

    with patch.object(s, "get_comm_cost", side_effect=fake_comm_cost):
        stealing.balance()

    rejected = sum(stealing.metrics["reject_count_margin_total"].values())
    assert rejected >= 1


@gen_cluster(
nthreads=[("", 1)],
client=True,
Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2995,6 +2995,8 @@ async def test_log_remove_worker(c, s, a, b):
"(stimulus_id='ungraceful')",
f"Removing worker '{b.address}' caused the cluster to lose scattered "
"data, which can't be recovered: {'z'} (stimulus_id='ungraceful')",
f"Worker {b.address!r} dropped unexpectedly. Interrupting 1 "
"processing tasks: {'y'} (stimulus_id='ungraceful')",
"Lost all workers",
]

Expand Down
Loading