diff --git a/httpx/_content.py b/httpx/_content.py index 6f479a0885..fd81d7e486 100644 --- a/httpx/_content.py +++ b/httpx/_content.py @@ -35,8 +35,25 @@ def __init__(self, stream: bytes) -> None: def __iter__(self) -> Iterator[bytes]: yield self._stream - async def __aiter__(self) -> AsyncIterator[bytes]: - yield self._stream + def __aiter__(self) -> AsyncIterator[bytes]: + return _ByteStreamAsyncIterator(self._stream) + + +class _ByteStreamAsyncIterator: + __slots__ = ("_stream",) + + def __init__(self, stream: bytes) -> None: + self._stream: bytes | None = stream + + def __aiter__(self) -> _ByteStreamAsyncIterator: + return self + + async def __anext__(self) -> bytes: + stream = self._stream + if stream is None: + raise StopAsyncIteration + self._stream = None # Consumed. + return stream class IteratorByteStream(SyncByteStream): diff --git a/tests/test_content.py b/tests/test_content.py index 9bfe983722..5885c7e719 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -32,6 +32,10 @@ async def test_bytes_content(): sync_content = b"".join(list(request.stream)) async_content = b"".join([part async for part in request.stream]) + # The async iterator returned by __aiter__ is itself an async iterator. + async_iterator = request.stream.__aiter__() + assert async_iterator.__aiter__() is async_iterator + assert request.headers == {"Host": "www.example.com", "Content-Length": "13"} assert sync_content == b"Hello, world!" assert async_content == b"Hello, world!"