Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 59 additions & 19 deletions pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ def _elastic_event_cleanup() -> None:
array.delete()


class ElasticRetryLimit:
"""A retry callback that limits the number of attempts."""

def __init__(self, max_attempts: int):
if max_attempts <= 0:
raise ValueError("max_attempts must be positive.")
self.max_attempts = max_attempts

def __call__(self, attempt: int, error: Exception) -> bool:
del error # Unused
return attempt < self.max_attempts


class Manager:
"""Utility class for elastic training.

Expand Down Expand Up @@ -191,12 +204,13 @@ def _monitor_new_slices(

def elastic_retry(
self,
max_retries: int,
max_retries: int | None = None,
minimum_slice_count: int | None = None,
poll_interval: float | int = 10,
timeout: float | None = None,
pre_callback: Callable[..., Any] | None = None,
on_elastic_event_callback: Callable[..., Any] | None = None,
retry_policy: Callable[[int, Exception], bool] | None = None,
) -> Callable[[_F], _F]:
"""Retries a function with elasticity fault tolerance.

Expand Down Expand Up @@ -224,6 +238,7 @@ def elastic_retry(

Args:
max_retries: The maximum number of times to retry the function.
Deprecated: Use `retry_policy` instead.
minimum_slice_count: The minimum number of slices required to run the
function. If None, defaults to the total number of slices.
poll_interval: The number of seconds to wait between activity checks.
Expand All @@ -233,6 +248,10 @@ def elastic_retry(
pre_callback: A callback to call before the function is attempted.
on_elastic_event_callback: A callback to call after an elastic failure
occurs.
retry_policy: A policy (callable) to determine if a retry should be
attempted. It accepts the attempt number (1-indexed) and the exception
that triggered the retry. If it returns False, no more retries are
attempted.

Returns:
A decorator that retries the wrapped function.
Expand All @@ -248,17 +267,23 @@ def elastic_retry(
else minimum_slice_count
)

if max_retries <= 0:
raise ValueError("max_retries must be positive.")
if max_retries is not None and retry_policy is not None:
raise ValueError("Cannot specify both max_retries and retry_policy.")

if retry_policy is None:
if max_retries is None:
retry_policy = lambda attempt, error: True
else:
if max_retries <= 0:
raise ValueError("max_retries must be positive.")
retry_policy = ElasticRetryLimit(max_retries)

def decorator(func: _F) -> _F:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:

def attempt_execution(retry_index: int) -> Any:
_logger.info(
"Elastic attempt %d out of %d", retry_index + 1, max_retries
)
def attempt_execution(attempt: int) -> Any:
_logger.info("Elastic attempt %d", attempt)
self.active_slice_indices = elastic.wait_for_slices(
slice_count=target_slice_count,
slice_to_devices=self.slice_to_devices,
Expand Down Expand Up @@ -289,34 +314,49 @@ def attempt_execution(retry_index: int) -> Any:
if monitor_thread is not None:
monitor_thread.join()

for retry_index in range(max_retries):
attempt = 1
while True:
try:
return attempt_execution(retry_index)
except ScaleUpSignalError:
_logger.info("Scale up requested. Retrying.")
return attempt_execution(attempt)
except ScaleUpSignalError as error:
_logger.info("Scale up requested.")
_elastic_event_cleanup()

if on_elastic_event_callback is not None:
on_elastic_event_callback()

if not retry_policy(attempt, error):
_logger.info(
"Retry policy rejected retry after ScaleUpSignalError."
)
raise ElasticRuntimeError(
f"Elastic attempt {attempt} failed."
) from error

_logger.info("Retrying.")
except jax.errors.JaxRuntimeError as error:
if not elastic.is_error_due_to_slice_down(error):
raise

if self.new_slice_event.is_set():
_logger.info(
"Slice down event and new slice available detected. Retrying."
)
_logger.info("Slice down event and new slice available detected.")
else:
_logger.info("Slice down event detected. Retrying.")
_logger.info("Slice down event detected.")

_elastic_event_cleanup()

if on_elastic_event_callback is not None:
on_elastic_event_callback()
else:
raise ElasticRuntimeError(
f"Elastic attempt {max_retries} out of {max_retries} failed."
)

if not retry_policy(attempt, error):
_logger.info("Retry policy rejected retry after JaxRuntimeError.")
raise ElasticRuntimeError(
f"Elastic attempt {attempt} failed."
) from error

_logger.info("Retrying.")

attempt += 1

return wrapper

Expand Down
Loading