Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 179 additions & 137 deletions benchmarks/utils/build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from benchmarks.utils.args_parser import get_parser
from benchmarks.utils.build_manifest import summarize_build_records
from benchmarks.utils.buildx_utils import (
BackgroundBuildKitPruner,
buildkit_disk_usage,
maybe_prune_buildkit_cache,
maybe_reset_buildkit,
Expand Down Expand Up @@ -747,6 +748,17 @@ def build_all_images(
# Prune aggressively by default; filters like "unused-for=12h" prevented GC from
# reclaiming layers created during the current run, leading to disk exhaustion.
prune_filters: list[str] | None = None
completed_batch_prune_enabled = (
os.getenv("BUILDKIT_COMPLETED_BATCH_PRUNE_ENABLED", "1") != "0"
)
completed_batch_prune_timeout_sec = int(
os.getenv("BUILDKIT_COMPLETED_BATCH_PRUNE_TIMEOUT_SEC", "300")
)
completed_batch_pruner = BackgroundBuildKitPruner(
enabled=completed_batch_prune_enabled,
timeout_sec=completed_batch_prune_timeout_sec,
)
prunable_batch: list[str] = []

def _chunks(seq: list[str], size: int):
if size <= 0:
Expand All @@ -758,163 +770,193 @@ def _chunks(seq: list[str], size: int):
batches = list(_chunks(base_images, batch_size or len(base_images)))
total_batches = len(batches)

with (
_prepare_cached_sdist() as cached_sdist,
manifest_file.open("w") as writer,
tqdm(
total=len(base_images), desc="Building agent-server images", leave=True
) as pbar,
):
_update_pbar(pbar, built, skipped, failures, 0, None, "Queueing")

for batch_idx, batch in enumerate(batches, start=1):
if not batch:
continue
try:
with (
_prepare_cached_sdist() as cached_sdist,
manifest_file.open("w") as writer,
tqdm(
total=len(base_images), desc="Building agent-server images", leave=True
) as pbar,
):
_update_pbar(pbar, built, skipped, failures, 0, None, "Queueing")

for batch_idx, batch in enumerate(batches, start=1):
if not batch:
continue

batch_started_monotonic = time.monotonic()
logger.info(
"Starting batch %d/%d (%d images)", batch_idx, total_batches, len(batch)
)
in_progress: set[str] = set()
batch_built = 0
batch_skipped = 0
batch_failures = 0

with ProcessPoolExecutor(max_workers=max_workers) as ex:
futures = {}
for base in batch:
in_progress.add(base)
resolved_tag = (
base_image_to_custom_tag_fn(base)
if base_image_to_custom_tag_fn
else ""
)
fut = ex.submit(
_build_with_logging,
log_dir=build_log_dir,
base_image=base,
target_image=image,
custom_tag=resolved_tag,
target=target,
push=push,
force_build=force_build,
max_retries=max_retries,
post_build_fn=post_build_fn,
cached_sdist=cached_sdist,
extra_build_args=extra_build_args,
)
futures[fut] = base

_update_pbar(
pbar,
built,
skipped,
failures,
len(in_progress),
next(iter(in_progress), None),
f"Batch {batch_idx}/{total_batches} running",
)
completed_batch_pruner.poll()
if prunable_batch:
# Prune the previous batch (one-batch delay ensures we don't
# prune layers from the currently running batch).
completed_batch_pruner.enqueue_completed_batch(prunable_batch)
prunable_batch = []
Comment thread
neubig marked this conversation as resolved.

for fut in as_completed(futures):
base = futures[fut]
status = None
try:
result: BuildOutput = fut.result()
except Exception as e:
logger.error("Build failed for %s: %r", base, e)
result = BuildOutput(
batch_started_monotonic = time.monotonic()
logger.info(
"Starting batch %d/%d (%d images)",
batch_idx,
total_batches,
len(batch),
)
in_progress: set[str] = set()
completed_in_batch: list[str] = []
batch_built = 0
batch_skipped = 0
batch_failures = 0

with ProcessPoolExecutor(max_workers=max_workers) as ex:
futures = {}
for base in batch:
in_progress.add(base)
resolved_tag = (
base_image_to_custom_tag_fn(base)
if base_image_to_custom_tag_fn
else ""
)
fut = ex.submit(
_build_with_logging,
log_dir=build_log_dir,
base_image=base,
tags=[],
error=repr(e),
status="failed",
target_image=image,
custom_tag=resolved_tag,
target=target,
push=push,
force_build=force_build,
max_retries=max_retries,
post_build_fn=post_build_fn,
cached_sdist=cached_sdist,
extra_build_args=extra_build_args,
)
futures[fut] = base

writer.write(result.model_dump_json() + "\n")
writer.flush()
results.append(result)

with mu:
if result.error or not result.tags:
failures += 1
batch_failures += 1
status = "❌ Failed"
elif result.status == "skipped_remote_exists":
skipped += 1
batch_skipped += 1
status = "⏭ Skipped"
else:
built += 1
batch_built += 1
status = "✅ Built"

in_progress.discard(base)
pbar.update(1)
_update_pbar(
pbar,
built,
skipped,
failures,
len(in_progress),
next(iter(in_progress), None),
status,
f"Batch {batch_idx}/{total_batches} running",
)
logger.debug(
"Image %s completed status=%s attempts=%d duration=%ss build=%ss remote_check=%ss post_build=%ss",
base,
result.status,
result.attempt_count,
result.duration_seconds,
result.build_seconds,
result.remote_check_seconds,
result.post_build_seconds,

for fut in as_completed(futures):
base = futures[fut]
status = None
try:
result: BuildOutput = fut.result()
except Exception as e:
logger.error("Build failed for %s: %r", base, e)
result = BuildOutput(
base_image=base,
tags=[],
error=repr(e),
status="failed",
)

writer.write(result.model_dump_json() + "\n")
writer.flush()
results.append(result)

with mu:
if result.error or not result.tags:
failures += 1
batch_failures += 1
status = "❌ Failed"
elif result.status == "skipped_remote_exists":
skipped += 1
batch_skipped += 1
status = "⏭ Skipped"
else:
built += 1
batch_built += 1
completed_in_batch.append(base)
status = "✅ Built"

in_progress.discard(base)
pbar.update(1)
_update_pbar(
pbar,
built,
skipped,
failures,
len(in_progress),
next(iter(in_progress), None),
status,
)
logger.debug(
"Image %s completed status=%s attempts=%d duration=%ss build=%ss remote_check=%ss post_build=%ss",
base,
result.status,
result.attempt_count,
result.duration_seconds,
result.build_seconds,
result.remote_check_seconds,
result.post_build_seconds,
)

used, total = buildkit_disk_usage()
if total > 0:
logger.info(
"BuildKit usage after batch %d/%d: %.2f%% (%0.2f GiB / %0.2f GiB)",
batch_idx,
total_batches,
(used / total) * 100,
used / (1 << 30),
total / (1 << 30),
)

used, total = buildkit_disk_usage()
if total > 0:
prunable_batch = completed_in_batch

if prune_keep_storage_gb and prune_keep_storage_gb > 0:
if completed_batch_pruner.is_busy:
logger.info(
"Skipping synchronous BuildKit prune after batch %d/%d while background targeted prune is still running or queued",
batch_idx,
total_batches,
)
else:
pruned = maybe_prune_buildkit_cache(
keep_storage_gb=prune_keep_storage_gb,
threshold_pct=prune_threshold_pct,
filters=prune_filters,
)
if pruned:
logger.info(
"Pruned BuildKit cache after batch %d/%d (keep=%d GiB, threshold=%.1f%%)",
batch_idx,
total_batches,
prune_keep_storage_gb,
prune_threshold_pct,
)
else:
logger.info(
"No prune needed after batch %d/%d (threshold %.1f%%)",
batch_idx,
total_batches,
prune_threshold_pct,
)

batch_duration = time.monotonic() - batch_started_monotonic
batch_throughput = (
(batch_built / batch_duration) * 3600 if batch_duration else 0.0
)
logger.info(
"BuildKit usage after batch %d/%d: %.2f%% (%0.2f GiB / %0.2f GiB)",
"Finished batch %d/%d in %.1fs: built=%d skipped=%d failed=%d throughput=%.1f built images/hour",
batch_idx,
total_batches,
(used / total) * 100,
used / (1 << 30),
total / (1 << 30),
batch_duration,
batch_built,
batch_skipped,
batch_failures,
batch_throughput,
)

if prune_keep_storage_gb and prune_keep_storage_gb > 0:
pruned = maybe_prune_buildkit_cache(
keep_storage_gb=prune_keep_storage_gb,
threshold_pct=prune_threshold_pct,
filters=prune_filters,
)
if pruned:
logger.info(
"Pruned BuildKit cache after batch %d/%d (keep=%d GiB, threshold=%.1f%%)",
batch_idx,
total_batches,
prune_keep_storage_gb,
prune_threshold_pct,
)
else:
logger.info(
"No prune needed after batch %d/%d (threshold %.1f%%)",
batch_idx,
total_batches,
prune_threshold_pct,
)
batch_duration = time.monotonic() - batch_started_monotonic
batch_throughput = (
(batch_built / batch_duration) * 3600 if batch_duration else 0.0
)
logger.info(
"Finished batch %d/%d in %.1fs: built=%d skipped=%d failed=%d throughput=%.1f built images/hour",
batch_idx,
total_batches,
batch_duration,
batch_built,
batch_skipped,
batch_failures,
batch_throughput,
)
completed_batch_pruner.poll()
if prunable_batch:
completed_batch_pruner.enqueue_completed_batch(prunable_batch)
prunable_batch = []
finally:
completed_batch_pruner.wait()

summary_file = build_dir / "build-summary.json"
summary = summarize_build_records(
Expand Down
Loading
Loading