Skip to content
Merged

Trio #138

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ jobs:
- name: "Install dependencies"
run: "scripts/install"
- name: "Run tests"
run: "scripts/test"
run: "scripts/test"
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
-e .

trio==0.33.0

# Build...
build==1.2.2

# Test...
mypy==1.15.0
pytest==8.3.5
pytest-cov==6.1.1
pytest-trio==0.8.0

# Sync & Async mirroring...
unasync==0.6.0
Expand Down
3 changes: 1 addition & 2 deletions scripts/test
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@ if [ -d 'venv' ] ; then
export PREFIX="venv/bin/"
fi

${PREFIX}mypy src/httpx
${PREFIX}mypy src/ahttpx
${PREFIX}pytest --cov src/httpx tests
${PREFIX}pytest --cov src/ahttpx tests/test_ahttpx
48 changes: 48 additions & 0 deletions scripts/unasync
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,51 @@ unasync.unasync_files(
),
]
)


unasync.unasync_files(
fpath_list = [
"tests/test_ahttpx/test_client.py",
"tests/test_ahttpx/test_content.py",
"tests/test_ahttpx/test_headers.py",
"tests/test_ahttpx/test_network.py",
"tests/test_ahttpx/test_parsers.py",
"tests/test_ahttpx/test_pool.py",
"tests/test_ahttpx/test_quickstart.py",
"tests/test_ahttpx/test_request.py",
"tests/test_ahttpx/test_response.py",
"tests/test_ahttpx/test_streams.py",
"tests/test_ahttpx/test_urlencode.py",
"tests/test_ahttpx/test_urls.py",
],
rules = [
unasync.Rule(
"tests/test_ahttpx/",
"tests/test_httpx/",
additional_replacements={"ahttpx": "httpx"}
),
]
)


for path in [
"tests/test_httpx/test_client.py",
"tests/test_httpx/test_content.py",
"tests/test_httpx/test_headers.py",
"tests/test_httpx/test_network.py",
"tests/test_httpx/test_parsers.py",
"tests/test_httpx/test_pool.py",
"tests/test_httpx/test_quickstart.py",
"tests/test_httpx/test_request.py",
"tests/test_httpx/test_response.py",
"tests/test_httpx/test_streams.py",
"tests/test_httpx/test_urlencode.py",
"tests/test_httpx/test_urls.py",
]:
with open(path, "r") as fin:
lines = fin.readlines()

lines = [line for line in lines if line != "@pytest.mark.trio\n"]

with open(path, "w") as fout:
fout.writelines(lines)
79 changes: 47 additions & 32 deletions src/ahttpx/_network.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import ssl
import types
import typing

import trio
import certifi

from ._streams import Stream
Expand All @@ -13,39 +13,37 @@

class NetworkStream(Stream):
def __init__(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, address: str = ''
self, trio_stream: trio.abc.Stream, address: str = ''
) -> None:
self._reader = reader
self._writer = writer
self._trio_stream = trio_stream
self._address = address
self._tls = False
self._closed = False

async def read(self, size: int = -1) -> bytes:
if size < 0:
size = 64 * 1024
return await self._reader.read(size)
return await self._trio_stream.receive_some(size)

async def write(self, buffer: bytes) -> None:
self._writer.write(buffer)
await self._writer.drain()
await self._trio_stream.send_all(buffer)

async def close(self) -> None:
if not self._closed:
self._writer.close()
await self._writer.wait_closed()
# Close the NetworkStream.
# If the stream is already closed this is a checkpointed no-op.
try:
await self._trio_stream.aclose()
finally:
self._closed = True

def __repr__(self):
description = ""
description += " TLS" if self._tls else ""
description += " CLOSED" if self._closed else ""
return f"<NetworkStream [{self._address!r}{description}]>"
return f"<NetworkStream [{self._address}{description}]>"

def __del__(self):
if not self._closed:
import warnings
warnings.warn("NetworkStream was garbage collected without being closed.")
warnings.warn(f"{self!r} was garbage collected without being closed.")

# Context managed usage...
async def __aenter__(self) -> "NetworkStream":
Expand All @@ -61,13 +59,17 @@ async def __aexit__(


class NetworkServer:
def __init__(self, host: str, port: int, server: asyncio.Server):
def __init__(self, host: str, port: int, handler, listeners: list[trio.SocketListener]):
self.host = host
self.port = port
self._server = server
self._handler = handler
self._listeners = listeners

# Context managed usage...
async def __aenter__(self) -> "NetworkServer":
self._nursery_manager = trio.open_nursery()
self._nursery = await self._nursery_manager.__aenter__()
self._nursery.start_soon(trio.serve_listeners, self._handler, self._listeners)
return self

async def __aexit__(
Expand All @@ -76,8 +78,8 @@ async def __aexit__(
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
):
self._server.close()
await self._server.wait_closed()
self._nursery.cancel_scope.cancel()
await self._nursery_manager.__aexit__(exc_type, exc_value, traceback)


class NetworkBackend:
Expand All @@ -92,29 +94,42 @@ async def connect(self, host: str, port: int) -> NetworkStream:
"""
Connect to the given address, returning a Stream instance.
"""
# Create the TCP stream
address = f"{host}:{port}"
reader, writer = await asyncio.open_connection(host, port)
return NetworkStream(reader, writer, address=address)
trio_stream = await trio.open_tcp_stream(host, port)
return NetworkStream(trio_stream, address=address)

async def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream:
"""
Connect to the given address, returning a Stream instance.
"""
# Create the TCP stream
address = f"{host}:{port}"
reader, writer = await asyncio.open_connection(host, port)
await writer.start_tls(self._ssl_ctx, server_hostname=hostname)
return NetworkStream(reader, writer, address=address)
trio_stream = await trio.open_tcp_stream(host, port)

# Establish SSL over TCP
hostname = hostname or host
ssl_stream = trio.SSLStream(trio_stream, ssl_context=self._ssl_ctx, server_hostname=hostname)
await ssl_stream.do_handshake()

return NetworkStream(ssl_stream, address=address)

async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer:
async def callback(reader, writer):
stream = NetworkStream(reader, writer)
await handler(stream)
async def callback(trio_stream):
stream = NetworkStream(trio_stream, address=f"{host}:{port}")
try:
await handler(stream)
finally:
await stream.close()

server = await asyncio.start_server(callback, host, port)
return NetworkServer(host, port, server)
listeners = await trio.open_tcp_listeners(port=port, host=host)
return NetworkServer(host, port, callback, listeners)

def __repr__(self):
return f"<NetworkBackend [trio]>"


Semaphore = asyncio.Semaphore
Lock = asyncio.Lock
timeout = asyncio.timeout
sleep = asyncio.sleep
Semaphore = trio.Semaphore
Lock = trio.Lock
timeout = trio.move_on_after
sleep = trio.sleep
22 changes: 22 additions & 0 deletions src/ahttpx/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ async def send_body(self, body: bytes) -> None:
# Handle body close
self.send_state = State.DONE

async def recv_close(self) -> bool:
# ...
if self.is_closed():
return True

if await self.parser.read_eof():
await self.close()
return True
return False

async def recv_method_line(self) -> tuple[bytes, bytes, bytes]:
"""
Receive the initial request method line:
Expand Down Expand Up @@ -463,6 +473,18 @@ async def read(self, size: int) -> bytes:
self._push_back(bytes(push_back))
return bytes(buffer)

async def read_eof(self) -> bool:
"""
Attempt to read the closing EOF.
Return True if the stream is EOF, or False otherwise.
"""
if not self._buffer:
chunk = await self._read_some()
if not chunk:
return True
self._push_back(chunk)
return False

async def read_until(self, marker: bytes, max_size: int, exc_text: str) -> bytes:
"""
Read and return bytes from the stream, delimited by marker.
Expand Down
8 changes: 4 additions & 4 deletions src/ahttpx/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, stream, endpoint):
# API entry points...
async def handle_requests(self):
try:
while not self._parser.is_closed():
while not await self._parser.recv_close():
method, url, headers = await self._recv_head()
stream = HTTPStream(self._recv_body, self._complete)
# TODO: Handle endpoint exceptions
Expand All @@ -43,13 +43,13 @@ async def handle_requests(self):
except Exception:
logger.error("Internal Server Error", exc_info=True)
content = Text("Internal Server Error")
err = Response(code=500, content=content)
err = Response(500, content=content)
await self._send_head(err)
await self._send_body(err)
else:
await self._send_head(response)
await self._send_body(response)
except Exception:
except BaseException:
logger.error("Internal Server Error", exc_info=True)

async def close(self):
Expand Down Expand Up @@ -89,7 +89,7 @@ async def _send_body(self, response: Response):

# Start it all over again...
async def _complete(self):
await self._parser.complete
await self._parser.complete()
self._idle_expiry = time.monotonic() + self._keepalive_duration


Expand Down
6 changes: 5 additions & 1 deletion src/httpx/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self, listener: NetworkListener, handler: typing.Callable[[NetworkS
self._max_workers = 5
self._executor = None
self._thread = None
self._streams = list[NetworkStream]
self._streams: list[NetworkStream] = []

@property
def host(self):
Expand All @@ -176,6 +176,8 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
for stream in self._streams:
stream.close()
self.listener.close()
self._executor.shutdown(wait=True)

Expand All @@ -185,9 +187,11 @@ def _serve(self):

def _handler(self, stream):
try:
self._streams.append(stream)
self.handler(stream)
finally:
stream.close()
self._streams.remove(stream)


class NetworkBackend:
Expand Down
22 changes: 22 additions & 0 deletions src/httpx/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ def send_body(self, body: bytes) -> None:
# Handle body close
self.send_state = State.DONE

def recv_close(self) -> bool:
# ...
if self.is_closed():
return True

if self.parser.read_eof():
self.close()
return True
return False

def recv_method_line(self) -> tuple[bytes, bytes, bytes]:
"""
Receive the initial request method line:
Expand Down Expand Up @@ -463,6 +473,18 @@ def read(self, size: int) -> bytes:
self._push_back(bytes(push_back))
return bytes(buffer)

def read_eof(self) -> bool:
"""
Attempt to read the closing EOF.
Return True if the stream is EOF, or False otherwise.
"""
if not self._buffer:
chunk = self._read_some()
if not chunk:
return True
self._push_back(chunk)
return False

def read_until(self, marker: bytes, max_size: int, exc_text: str) -> bytes:
"""
Read and return bytes from the stream, delimited by marker.
Expand Down
Loading