feat: support concurrency in base sampling strategy#1175
Signed-off-by: Jake LoRocco <jake.lorocco@ibm.com> Assisted-by: CLAUDE:OPUS
2a3dbc1 to
e72ebde
Compare
planetf1
left a comment
There was a problem hiding this comment.
A few cross-cutting observations alongside the line-level notes — mostly about behaviour shifts that users and plugin authors will hit but the diff itself doesn't surface. Happy to be told any of these are deliberate scope cuts for a follow-up.
Cost & rate limits. concurrency_budget=N multiplies the worst-case request count per sample() by N (and the expected count by some factor between 1 and N depending on success rate). For paid backends that's a real spend multiplier, and against rate-limited ones (Anthropic/OpenAI/Watsonx) we'll hit 429s sooner — possibly turning a previously-passing sample loop into a hard failure. The class docstring describes the mechanic but doesn't warn about either. Probably worth a one-line note in the Args: block, even if rate-limit handling is a separate piece of work.
Determinism. Two identical sample() calls with concurrency_budget>1 can return different winning slices depending on which subsample's network call returns first. Existing qualitative tests and example notebooks that assume sampling stability may start to flake. A docstring note saying "selection is non-deterministic when concurrency_budget>1" would set expectations.
with_context(sampling_iteration=...) under concurrency. contextvars is the right primitive for this (survives create_task cleanly), and a quick read suggests it'll be fine, but I didn't see a test that pins the behaviour — i.e. that when subsample A is on iteration 3 and subsample B is on iteration 7, plugin handlers see the right iteration number for the slice they're processing. If this ever bleeds, every SAMPLING_ITERATION-driven plugin is wrong under concurrency. A small assertion test would lock it down.
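A minimal sketch of the kind of assertion test meant here, using only the standard library. The `sampling_iteration` variable and the `subsample` coroutine are hypothetical stand-ins, not the project's `with_context` helper; the point is only that each task created with `create_task` reads back the value it set, even when tasks interleave.

```python
import asyncio
import contextvars

# Hypothetical stand-in for the value that with_context(sampling_iteration=...) sets.
sampling_iteration = contextvars.ContextVar("sampling_iteration")


async def subsample(iteration: int, seen: dict) -> None:
    # Each task sets the variable inside its own copy of the context.
    sampling_iteration.set(iteration)
    # Yield control so the other task runs before the value is read back.
    await asyncio.sleep(0)
    seen[iteration] = sampling_iteration.get()


async def run() -> dict:
    seen: dict = {}
    # create_task copies the current context, so values should not bleed across tasks.
    await asyncio.gather(
        asyncio.create_task(subsample(3, seen)),
        asyncio.create_task(subsample(7, seen)),
    )
    return seen


observed = asyncio.run(run())
```

A real test would read the iteration from inside a plugin handler rather than from the task body, but the isolation property being pinned is the same.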
Cancellation cleanup. On early success the producers are cancelled mid-generate_from_context. Backend cleanup paths for in-flight HTTP / streaming response handles vary across providers. Not aware of an existing leak, but there's no test asserting "after early success, no in-flight backend calls remain". Could surface later as flaky test resource warnings under load.
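One possible shape for such a test, sketched with plain `asyncio` tasks. The `producer` coroutine and its delays are illustrative assumptions. This only checks that no tasks remain pending; whether a given backend releases its HTTP or streaming handles on cancellation would need a provider-specific check.

```python
import asyncio


async def producer(delay: float, queue: asyncio.Queue, name: str) -> None:
    # Stand-in for a subsample that is blocked inside a backend call.
    await asyncio.sleep(delay)
    await queue.put(name)


async def first_success():
    queue: asyncio.Queue = asyncio.Queue()
    tasks = [
        asyncio.create_task(producer(0.01, queue, "fast")),
        asyncio.create_task(producer(10.0, queue, "slow")),
    ]
    winner = await queue.get()  # Early success: the first result wins.
    for t in tasks:
        t.cancel()  # No-op if the task is already done.
    # Wait for cancellations to settle so that no task is left pending.
    await asyncio.gather(*tasks, return_exceptions=True)
    return winner, tasks


winner, tasks = asyncio.run(first_success())
```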
Hook ordering. SAMPLING_ITERATION events now interleave across subsamples — the iteration field is globally unique (subsample_index * iterations + i + 1, nice), but consumers that assumed monotonic ordering will misbehave. Worth a one-line note on SamplingIterationPayload.
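To make the ordering point concrete, here is a small sketch of the formula quoted above. It assumes `iterations` equals the loop budget, and `global_iteration` is a hypothetical helper name. The ids are unique, but one plausible arrival order is not monotonic.

```python
def global_iteration(subsample_index: int, iterations: int, i: int) -> int:
    # Formula quoted in the review: subsample_index * iterations + i + 1.
    return subsample_index * iterations + i + 1


loop_budget, concurrency_budget = 3, 2

# Ids grouped by subsample: every (subsample, step) pair gets a distinct id.
ids = [
    global_iteration(s, loop_budget, i)
    for s in range(concurrency_budget)
    for i in range(loop_budget)
]

# One possible arrival order when the two subsamples advance in lockstep.
interleaved = [
    global_iteration(s, loop_budget, i)
    for i in range(loop_budget)
    for s in range(concurrency_budget)
]
```

A consumer that keeps "the highest iteration seen so far" as a progress marker would therefore see the marker jump around under concurrency.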
Docs & examples. Couldn't find a worked example of concurrency_budget in docs/examples/ or a callout in docs/AGENTS_TEMPLATE.md. Given how visible the speedup will be once people find it, an example showing "here's how to use it, here's the cost trade-off" would prevent both under-use and footgun-misuse.
Test gaps that fall out of the line-level points. Three the suite would benefit from:
- Backend exception under concurrency_budget=1 (covers the failure mode in the gather/_DONE comment).
- MultiTurnStrategy with concurrency_budget>1 exhausting all attempts, asserting which slice select_from_failure actually picks (covers the ordering note).
- Cancellation cleanup: assert no producer tasks remain pending after early success.
None of this is blocking from my side; mostly flagging for visibility so we don't ship the concurrency knob without the matching guardrails.
| t.cancel() # No-op if already done / cancelled. | ||
| # Wait for cancellations to settle so we don't leak tasks. | ||
| await asyncio.gather(*producer_tasks, return_exceptions=True) |
There was a problem hiding this comment.
One thing to flag on the cleanup path — if the backend raises inside generate_from_context, the exception gets absorbed by _producer's finally (which still puts _DONE) and then dropped by gather(return_exceptions=True). The consumer ends up with an empty slices and the user sees AssertionError: result index cannot be out of range from SamplingResult.__init__, with nothing pointing back at the real cause.
In the pre-PR sequential path the exception propagated directly. Wondering if it's worth re-raising the first non-cancellation exception from the gather when no slices made it through, just for the default concurrency_budget=1 case:
| await asyncio.gather(*producer_tasks, return_exceptions=True) | |
| _gathered = await asyncio.gather(*producer_tasks, return_exceptions=True) | |
| if not slices: | |
| _exc = next( | |
| (r for r in _gathered if isinstance(r, BaseException) and not isinstance(r, asyncio.CancelledError)), | |
| None, | |
| ) | |
| if _exc: | |
| raise RuntimeError("all sampling subsamples failed to produce a result") from _exc |
There was a problem hiding this comment.
Done. Thanks for catching this. Added a few tests for this as well.
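For readers following along, a self-contained sketch of the failure mode and the suggested guard. `failing_producer` and this simplified `sample` are stand-ins, not the code in the PR; the `gather`/re-raise logic mirrors the suggestion above.

```python
import asyncio


async def failing_producer() -> None:
    # Stand-in for a backend call that raises inside a producer.
    raise ConnectionError("backend unavailable")


async def sample() -> list:
    slices: list = []  # Stays empty because every producer fails.
    tasks = [asyncio.create_task(failing_producer()) for _ in range(2)]
    # return_exceptions=True returns exceptions as results instead of raising them.
    gathered = await asyncio.gather(*tasks, return_exceptions=True)
    if not slices:
        exc = next(
            (
                r
                for r in gathered
                if isinstance(r, BaseException)
                and not isinstance(r, asyncio.CancelledError)
            ),
            None,
        )
        if exc is not None:
            # Chain the original error so the real cause appears in the traceback.
            raise RuntimeError(
                "all sampling subsamples failed to produce a result"
            ) from exc
    return slices


try:
    asyncio.run(sample())
    cause = None
except RuntimeError as err:
    cause = err.__cause__
```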
| @@ -187,205 +259,288 @@ async def sample( | |||
| ) | |||
| effective_loop_budget = start_payload.loop_budget | |||
There was a problem hiding this comment.
One thing to flag: a SAMPLING_LOOP_START hook can return loop_budget=0 (or negative), and the post-hook value is taken as-is. total_possible_generations then collapses to 0, no slices are produced, and we land on the same opaque AssertionError as the swallowed-exception case. The constructor assert validates self.loop_budget but the hook bypasses that.
| effective_loop_budget = start_payload.loop_budget | |
| effective_loop_budget = start_payload.loop_budget | |
| assert effective_loop_budget > 0, ( | |
| "SAMPLING_LOOP_START hook returned loop_budget <= 0; refusing to run." | |
| ) |
There was a problem hiding this comment.
Agreed. Added another check and a test for this.
| # Module-level counter used by the "every 5th call passes" requirement below. | ||
| _validation_counter = 0 |
There was a problem hiding this comment.
The module-level _validation_counter survives across tests in the same worker, and the global reads make the dependency easy to miss when reading the test in isolation. Could be lifted into the test function as a nonlocal closure variable instead — the surrounding tests already use that pattern. Not load-bearing, just a small nudge for future-us.
There was a problem hiding this comment.
Agreed. Changed.
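A sketch of the closure pattern being suggested, assuming the requirement only needs "every 5th call passes". `make_every_nth_validator` is a hypothetical name, not the helper used in the test suite.

```python
def make_every_nth_validator(n: int):
    calls = 0  # Lives in the closure, so each test gets a fresh counter.

    def validate(_output: str) -> bool:
        nonlocal calls
        calls += 1
        return calls % n == 0

    return validate


validate = make_every_nth_validator(5)
results = [validate("x") for _ in range(10)]

# A second validator starts counting from zero again.
fresh = make_every_nth_validator(5)
first_fresh_call = fresh("x")
```

Because the counter is created per call to the factory, no state carries over between tests that run in the same worker.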
| all_results=sampled_results, | ||
| all_validations=sampled_scores, | ||
| success=s_result.success, | ||
| iterations_used=len(slices), |
There was a problem hiding this comment.
iterations_used previously meant "how many sequential generate/validate cycles ran" — under concurrency it's now the total slice count across all subsamples (so up to loop_budget * concurrency_budget). Both readings are reasonable, but the field name and SamplingLoopEndPayload's docstring haven't moved with it, so an existing plugin will silently start seeing different numbers. Either a rename (slices_observed?) or a one-line note on the payload would probably be enough.
There was a problem hiding this comment.
Added a comment.
| if progress_indicator is not None: | ||
| progress_indicator.close() | ||
| s_result = _get_sampling_result( |
There was a problem hiding this comment.
slices arrives in queue order rather than per-subsample iteration order, so it's now interleaved across concurrent subsamples. Strategies whose select_from_failure returns -1 (e.g. MultiTurnStrategy) used to mean "the deepest-repaired turn"; with concurrency_budget>1 it's just whichever subsample finished last. Might be worth sorting by (subsample_index, iteration) here before handing off, or noting in MultiTurnStrategy.select_from_failure's docstring that the contract shifts under concurrency.
There was a problem hiding this comment.
Added a comment that it's non-deterministic; but the behavior is basically the same, the last slice added to the queue will still be the last turn of one of the concurrent multi-turn subsamples.
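For reference, a sketch of the sorting option raised in the review. The `Slice` record and its field names are hypothetical; the slice type in the PR may carry this information differently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Slice:
    # Hypothetical record used only for this illustration.
    subsample_index: int
    iteration: int


# Queue order: interleaved across two concurrent subsamples.
slices = [Slice(1, 1), Slice(0, 1), Slice(0, 2), Slice(1, 2), Slice(0, 3)]

# Sorting makes index -1 independent of arrival order.
ordered = sorted(slices, key=lambda s: (s.subsample_index, s.iteration))
```

With this key, `ordered[-1]` is always the last turn of the highest-indexed subsample, whereas `slices[-1]` depends on which subsample happened to finish last. Neither is guaranteed to be the deepest-repaired turn overall.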
jakelorocco
left a comment
There was a problem hiding this comment.
For the other comments in the main body:
- There is a test for unique sampling iteration ids: test_sampling_iteration_ids_unique_under_concurrency.
- Added a comment to SamplingIterationPayload for monotonicity.
- Modified an example to show the concurrency budget option.
Signed-off-by: Jake LoRocco <jake.lorocco@ibm.com> Assisted-by: CLAUDE:OPUS
| iterations_used: Total number of iterations the executed. For concurrent sampling, this | ||
| corresponds to the loop_budget * concurrency budget. |
There was a problem hiding this comment.
The iterations_used docstring says "corresponds to the loop_budget * concurrency budget" but sample() passes len(slices) — the count of slices that actually completed, which can be 1 if the strategy exits on the first success. Also grammar error: "iterations the executed".
| iterations_used: Total number of iterations the executed. For concurrent sampling, this | |
| corresponds to the loop_budget * concurrency budget. | |
| iterations_used: Total number of sampling iterations that completed. With concurrency | |
| enabled, this may be less than ``loop_budget * concurrency_budget`` if the strategy | |
| exits early after a successful result. |
| backend: Backend, | ||
| requirements: list[Requirement], | ||
| *, | ||
| validation_ctx: Context | None = None, |
There was a problem hiding this comment.
_subsample_iteration accepts validation_ctx but never passes it to mfuncs.avalidate — the call at line 478 uses context=result_ctx unconditionally. Any caller supplying a custom validation context will have it silently ignored. Was this intentional, or should it be context=validation_ctx or result_ctx?
There was a problem hiding this comment.
This is intentional. Our current sampling strategy has the same issue and there's a larger open issue to fix how validation handles its inputs / contexts.
| # If no slices made it through, surface the first non-cancellation | ||
| # exception from a producer rather than letting the empty-slices | ||
| # state crash later in SamplingResult with a misleading assertion. | ||
| if not slices: |
There was a problem hiding this comment.
The re-raise guard only fires when not slices. If one subsample succeeds while a sibling raises a backend error (network failure, quota exceeded, etc.), the exception is captured in producer_results but silently discarded — the caller gets a success with no indication that part of the concurrent work failed.
Whether returning the success is the right trade-off is a design call, but at minimum the failure should be logged so it is visible in production:
for r in producer_results:
    if isinstance(r, BaseException) and not isinstance(r, asyncio.CancelledError):
        flog.warning("A concurrent subsample raised an exception: %s", r)
| assert loop_budget > 0, "Loop budget must be at least 1." | ||
| assert concurrency_budget > 0, "Concurrency budget must be at least 1." |
There was a problem hiding this comment.
assert statements are stripped when Python runs with optimisations (-O). Since sample() already uses raise ValueError for the hook-mutated budget path (line 269), it is worth being consistent here so invalid values are always rejected at construction.
| assert loop_budget > 0, "Loop budget must be at least 1." | |
| assert concurrency_budget > 0, "Concurrency budget must be at least 1." | |
| if loop_budget < 1: | |
| raise ValueError("Loop budget must be at least 1.") | |
| if concurrency_budget < 1: | |
| raise ValueError("Concurrency budget must be at least 1.") |
The class-level Raises: AssertionError entry (line 131) would need updating to ValueError too.
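Put together, the constructor check and its docstring might look like the sketch below. `BudgetedStrategy` is a hypothetical stand-in for the strategy class, reduced to the two budget parameters.

```python
class BudgetedStrategy:
    """Hypothetical stand-in for the strategy constructor discussed above.

    Raises:
        ValueError: If ``loop_budget`` or ``concurrency_budget`` is less than 1.
    """

    def __init__(self, loop_budget: int = 1, concurrency_budget: int = 1) -> None:
        # Explicit raises still run under ``python -O``; assert statements do not.
        if loop_budget < 1:
            raise ValueError("Loop budget must be at least 1.")
        if concurrency_budget < 1:
            raise ValueError("Concurrency budget must be at least 1.")
        self.loop_budget = loop_budget
        self.concurrency_budget = concurrency_budget


try:
    BudgetedStrategy(loop_budget=0)
    rejected = False
except ValueError:
    rejected = True

valid = BudgetedStrategy(loop_budget=2, concurrency_budget=3)
```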
| ) | ||
| assert s_result.result_index < len(s_result.sample_generations), ( | ||
| "The select_from_failure method did not return a valid result. It has to selected from failed_results." |
There was a problem hiding this comment.
Minor grammar nit.
| "The select_from_failure method did not return a valid result. It has to selected from failed_results." | |
| "The select_from_failure method did not return a valid result. It must return an index into failed_results." |
| # Additionally, you could specify 3 concurrent requests. | ||
| # This will potentially result in wasted samples and a high-number of | ||
| # requests due to concurrent validation as well. |
There was a problem hiding this comment.
With loop_budget=1, concurrency_budget=3 this is exactly 3 requests — "high number" overstates it. The real concern is rate-limiting on paid backends.
| # Additionally, you could specify 3 concurrent requests. | |
| # This will potentially result in wasted samples and a high-number of | |
| # requests due to concurrent validation as well. | |
| # Additionally, you could run 3 concurrent subsamples — stopping as soon as one | |
| # passes validation. This cuts wall-clock latency but multiplies request count, | |
| # which can trigger rate limits on paid backends. |
planetf1
left a comment
There was a problem hiding this comment.
Review pass — 7 findings (docstring accuracy, CI gate gap, silent parameter, exception visibility, assert vs ValueError, grammar, example copy). Suggestions can be applied with one click where provided.
| raise ValueError( | ||
| f"SAMPLING_LOOP_START hook returned non-positive loop_budget=" | ||
| f"{effective_loop_budget}; must be >= 1." | ||
| ) |
There was a problem hiding this comment.
This raise ValueError is not listed in the Raises: section of sample()'s docstring (currently only AssertionError is documented, a few lines above at line 225). The project's audit_coverage.py --quality CI gate enforces that every raise in a public function has a matching Raises: entry, so this will fail the build-and-validate job. The docstring needs:
ValueError: If a ``SAMPLING_LOOP_START`` hook returns a non-positive ``loop_budget``.
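As an illustration of how the raise and its docstring entry line up, here is a sketch that isolates the check in a hypothetical helper, `check_hook_budget`. In the PR the check lives inline in `sample()`, so the `Raises:` entry belongs in that method's docstring.

```python
def check_hook_budget(effective_loop_budget: int) -> int:
    """Validate the loop budget returned by a ``SAMPLING_LOOP_START`` hook.

    Args:
        effective_loop_budget: The post-hook loop budget.

    Returns:
        The validated loop budget.

    Raises:
        ValueError: If a ``SAMPLING_LOOP_START`` hook returns a non-positive ``loop_budget``.
    """
    if effective_loop_budget < 1:
        raise ValueError(
            "SAMPLING_LOOP_START hook returned non-positive loop_budget="
            f"{effective_loop_budget}; must be >= 1."
        )
    return effective_loop_budget


try:
    check_hook_budget(0)
    message = ""
except ValueError as err:
    message = str(err)
```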
planetf1
left a comment
There was a problem hiding this comment.
Requesting changes on two confirmed blockers (detail in the inline comments). The new raise ValueError in sample() (base.py L269) has no matching Raises: entry, which fails the Docstring quality gate step in docs-publish.yml — that step runs on every PR touching mellea/** and has no continue-on-error, so build-and-validate will go red. Separately, validation_ctx is threaded into _subsample_iteration but avalidate validates against result_ctx, so the public validation_ctx parameter is silently ignored — a behavioural regression. The remaining five comments are nits the author can take at their discretion.
Signed-off-by: Jake LoRocco <jake.lorocco@ibm.com>
a9369d5 to
5fa4ada
Compare
@planetf1 I believe I've addressed all your comments. It seems like a fair number were duplicates / tests of some sort so please let me know if I missed one.
Ah, sorry about that, it was unintentional (AI connectivity/retry). I think the only remaining issue is the missing Raises: entry in the docstring. In theory the build process should detect it, but having checked, it only looks for the presence of a Raises: section, not its values. So we need a ValueError entry in the docstring around line 229 of base.py:sample so that the generated API docs are correct.
Pull Request
Issue
Fixes N/A; builds on previously closed PR (#240)
Description
Adds concurrency to our base sampling strategy. We lacked a way to concurrently sample and were only able to sample iteratively. sample() now manages distinct generators. subsample_iteration replaces the actual sampling code that was previously in sample().
The total number of sampling iterations is now loop_budget * concurrency_budget. Concurrency budget is the breadth of the tree, loop_budget is the depth of the tree. At any time only concurrency_budget number of sampling iterations are running.
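A rough sketch of the breadth/depth shape described above, using a queue of results and early exit on the first success. All names here are illustrative assumptions (`subsample`, `passes_at`, the tuple layout), and generation plus validation is replaced by a no-op sleep; this is not the implementation in the PR.

```python
import asyncio


async def subsample(index: int, loop_budget: int, queue: asyncio.Queue, passes_at: int) -> None:
    # Depth: each subsample runs its iterations one after another.
    for i in range(loop_budget):
        await asyncio.sleep(0)  # Stand-in for generate + validate.
        iteration = index * loop_budget + i + 1
        await queue.put((iteration, iteration == passes_at))


async def sample(loop_budget: int, concurrency_budget: int, passes_at: int) -> list:
    queue: asyncio.Queue = asyncio.Queue()
    # Breadth: one task per subsample, so at most concurrency_budget run at once.
    tasks = [
        asyncio.create_task(subsample(s, loop_budget, queue, passes_at))
        for s in range(concurrency_budget)
    ]
    slices = []
    total = loop_budget * concurrency_budget
    while len(slices) < total:
        item = await queue.get()
        slices.append(item)
        if item[1]:  # Stop as soon as one iteration passes validation.
            break
    for t in tasks:
        t.cancel()  # No-op for tasks that already finished.
    await asyncio.gather(*tasks, return_exceptions=True)
    return slices


# passes_at=0 never matches an id (ids start at 1), so the full budget is used.
exhausted = asyncio.run(sample(3, 2, passes_at=0))
early = asyncio.run(sample(3, 2, passes_at=1))
```

When nothing passes, the loop collects all loop_budget * concurrency_budget results; when something passes, it returns early with fewer.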