diff --git a/contree_cli/cli/file.py b/contree_cli/cli/file.py index 968754b..75b9706 100644 --- a/contree_cli/cli/file.py +++ b/contree_cli/cli/file.py @@ -328,37 +328,35 @@ def cmd_file_ls(args: FileListArgs) -> int | None: if args.until is not None: params["until"] = isoformat_datetime(args.until) - fetcher = PaginatedFetcher( + emitted = 0 + hit_limit = False + with PaginatedFetcher( client, "/v1/files", params, lambda body: json.loads(body).get("files", []), limit=args.limit, concurrency=CONTREE_CONCURRENCY, - ) - - emitted = 0 - hit_limit = False - for page in fetcher: - for entry in page: - if emitted >= args.limit: - hit_limit = True + ) as fetcher: + for page in fetcher: + for entry in page: + if emitted >= args.limit: + hit_limit = True + break + uuid_str = entry.get("uuid") + source = sources.get(uuid_str, "") if isinstance(uuid_str, str) else "" + if args.quiet: + formatter( + uuid=uuid_str, + sha256=entry.get("sha256", ""), + source=source, + ) + else: + formatter(**{**entry, "source": source}) + emitted += 1 + formatter.flush() + if hit_limit: break - uuid_str = entry.get("uuid") - source = sources.get(uuid_str, "") if isinstance(uuid_str, str) else "" - if args.quiet: - formatter( - uuid=uuid_str, - sha256=entry.get("sha256", ""), - source=source, - ) - else: - formatter(**{**entry, "source": source}) - emitted += 1 - formatter.flush() - if hit_limit: - fetcher.stop() - break if hit_limit: logger.warning( diff --git a/contree_cli/cli/images.py b/contree_cli/cli/images.py index c142281..c60965d 100644 --- a/contree_cli/cli/images.py +++ b/contree_cli/cli/images.py @@ -272,28 +272,26 @@ def cmd_images(args: ImagesArgs) -> None: if args.until is not None: base_params["until"] = isoformat_datetime(args.until) - fetcher = PaginatedFetcher( + emitted = 0 + hit_limit = False + with PaginatedFetcher( client, "/v1/images", base_params, lambda body: json.loads(body)["images"], limit=args.limit, concurrency=CONTREE_CONCURRENCY, - ) - - emitted = 0 - hit_limit = False - for page in fetcher: - for image in page: - if emitted >= args.limit: - hit_limit = True + ) as fetcher: + for page in fetcher: + for image in page: + if emitted >= args.limit: + hit_limit = True + break + formatter(**image) + emitted += 1 + formatter.flush() + if hit_limit: break - formatter(**image) - emitted += 1 - formatter.flush() - if hit_limit: - fetcher.stop() - break if hit_limit: logger.warning( diff --git a/contree_cli/cli/operation.py b/contree_cli/cli/operation.py index 32e4399..1d09697 100644 --- a/contree_cli/cli/operation.py +++ b/contree_cli/cli/operation.py @@ -393,31 +393,29 @@ def cmd_list(args: ListArgs) -> None: base_params["until"] = isoformat_datetime(args.until) limit = args.show_max - fetcher = PaginatedFetcher( + emitted = 0 + hit_limit = False + with PaginatedFetcher( client, "/v1/operations", base_params, json.loads, limit=limit, concurrency=CONTREE_CONCURRENCY, - ) - - emitted = 0 - hit_limit = False - for page in fetcher: - for op in page: - if limit is not None and emitted >= limit: - hit_limit = True + ) as fetcher: + for page in fetcher: + for op in page: + if limit is not None and emitted >= limit: + hit_limit = True + break + if args.quiet: + print(op["uuid"]) + else: + formatter(**op) + emitted += 1 + formatter.flush() + if hit_limit: break - if args.quiet: - print(op["uuid"]) - else: - formatter(**op) - emitted += 1 - formatter.flush() - if hit_limit: - fetcher.stop() - break if hit_limit: logger.warning( diff --git a/contree_cli/client.py b/contree_cli/client.py index e49dd2f..6648acc 100644 --- a/contree_cli/client.py +++ b/contree_cli/client.py @@ -619,6 +619,17 @@ def stop(self) -> None: """Signal that the caller has seen enough; skip pending fetches.""" self._stop.set() + def __enter__(self) -> PaginatedFetcher: + return self + + def __exit__(self, *_: object) -> None: + # Setting the stop event short-circuits any worker that hasn't + # started yet and prevents the iterator's refill from enqueueing + # more fetches. Callers wrap iteration in `with PaginatedFetcher(...)` + # so they don't have to remember an explicit `stop()` after + # breaking out of a paged loop. + self.stop() + def _fetch(self, offset: int) -> list[dict[str, Any]]: if self._stop.is_set(): return [] diff --git a/tests/test_client.py b/tests/test_client.py index aefd871..23987a2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -668,11 +668,12 @@ def test_small_limit_uses_capped_page_size_in_request(self, contree_client): # `limit=5` must request `limit=6` (capped page size + 1 for the # truncation probe), not the default 1000. contree_client.respond_json({"items": [{"i": i} for i in range(6)]}) - f = self.make_fetcher(contree_client, limit=5) - pages_iter = iter(f) - first = next(pages_iter) - assert len(first) == 6 - f.stop() # mirror the real caller, which calls stop() after hitting limit + with self.make_fetcher(contree_client, limit=5) as f: + pages_iter = iter(f) + first = next(pages_iter) + assert len(first) == 6 + # Context manager exit calls stop() automatically; mirrors the + # real caller which breaks out of the loop after hitting limit. with contextlib.suppress(StopIteration): next(pages_iter) req = contree_client.get_request(0)