diff --git a/CHANGES/10113.bugfix.rst b/CHANGES/10113.bugfix.rst
new file mode 120000
index 00000000000..89cef58729f
--- /dev/null
+++ b/CHANGES/10113.bugfix.rst
@@ -0,0 +1 @@
+10101.bugfix.rst
\ No newline at end of file
diff --git a/CHANGES/10125.bugfix.rst b/CHANGES/10125.bugfix.rst
new file mode 100644
index 00000000000..4ece1e68d96
--- /dev/null
+++ b/CHANGES/10125.bugfix.rst
@@ -0,0 +1 @@
+Disabled zero copy writes in the ``StreamWriter`` -- by :user:`bdraco`.
diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py
index c66fda3d8d0..edd19ed65da 100644
--- a/aiohttp/http_writer.py
+++ b/aiohttp/http_writer.py
@@ -90,7 +90,7 @@ def _writelines(self, chunks: Iterable[bytes]) -> None:
         transport = self._protocol.transport
         if transport is None or transport.is_closing():
             raise ClientConnectionResetError("Cannot write to closing transport")
-        transport.writelines(chunks)
+        transport.write(b"".join(chunks))
 
     async def write(
         self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
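The ``http_writer.py`` hunk above is the change the ``10125.bugfix.rst`` entry describes: instead of handing the transport a list of buffers via ``writelines()`` (which may send them zero-copy), the chunks are eagerly coalesced into a single ``bytes`` object and passed to ``write()``. The diff itself does not state the motivation, but a minimal, self-contained sketch (plain Python, no aiohttp involved) of one general hazard of zero-copy writing with caller-owned mutable buffers::

    chunks = [bytearray(b"hello "), bytearray(b"world")]

    joined = b"".join(chunks)  # eager copy: a snapshot of the payload
    zero_copy = list(chunks)   # writelines-style: keeps buffer references only

    chunks[1][:] = b"XXXXX"    # the caller reuses its buffer after "sending"

    assert joined == b"hello world"               # unaffected by the reuse
    assert b"".join(zero_copy) == b"hello XXXXX"  # observes the mutation

Joining up front trades one extra copy for immunity to anything that happens to the source buffers after the call returns.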
""" + file_path, st, file_encoding = self._get_file_path_stat_encoding( + accept_encoding + ) + if not file_path: + return _FileResponseResult.NOT_ACCEPTABLE, None, st, None + + etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}" + + # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2 + if (ifmatch := request.if_match) is not None and not self._etag_match( + etag_value, ifmatch, weak=False + ): + return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding + + if ( + (unmodsince := request.if_unmodified_since) is not None + and ifmatch is None + and st.st_mtime > unmodsince.timestamp() + ): + return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding + + # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2 + if (ifnonematch := request.if_none_match) is not None and self._etag_match( + etag_value, ifnonematch, weak=True + ): + return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding + + if ( + (modsince := request.if_modified_since) is not None + and ifnonematch is None + and st.st_mtime <= modsince.timestamp() + ): + return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding + + fobj = file_path.open("rb") + with suppress(OSError): + # fstat() may not be available on all platforms + # Once we open the file, we want the fstat() to ensure + # the file has not changed between the first stat() + # and the open(). + st = os.stat(fobj.fileno()) + return _FileResponseResult.SEND_FILE, fobj, st, file_encoding + + def _get_file_path_stat_encoding( + self, accept_encoding: str + ) -> Tuple[Optional[pathlib.Path], os.stat_result, Optional[str]]: file_path = self._path for file_extension, file_encoding in ENCODING_EXTENSIONS.items(): if file_encoding not in accept_encoding: @@ -184,27 +242,11 @@ def _open_file_path_stat_encoding( # Do not follow symlinks and ignore any non-regular files. st = compressed_path.lstat() if S_ISREG(st.st_mode): - fobj = compressed_path.open("rb") - with suppress(OSError): - # fstat() may not be available on all platforms - # Once we open the file, we want the fstat() to ensure - # the file has not changed between the first stat() - # and the open(). - st = os.stat(fobj.fileno()) - return fobj, st, file_encoding + return compressed_path, st, file_encoding # Fallback to the uncompressed file st = file_path.stat() - if not S_ISREG(st.st_mode): - return None, st, None - fobj = file_path.open("rb") - with suppress(OSError): - # fstat() may not be available on all platforms - # Once we open the file, we want the fstat() to ensure - # the file has not changed between the first stat() - # and the open(). 
@@ -212,8 +254,8 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter
         # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
         accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
         try:
-            fobj, st, file_encoding = await loop.run_in_executor(
-                None, self._open_file_path_stat_encoding, accept_encoding
+            response_result, fobj, st, file_encoding = await loop.run_in_executor(
+                None, self._make_response, request, accept_encoding
             )
         except PermissionError:
             self.set_status(HTTPForbidden.status_code)
             return await super().prepare(request)
@@ -224,24 +266,32 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter
             self.set_status(HTTPNotFound.status_code)
             return await super().prepare(request)
 
-        try:
-            # Forbid special files like sockets, pipes, devices, etc.
-            if not fobj or not S_ISREG(st.st_mode):
-                self.set_status(HTTPForbidden.status_code)
-                return await super().prepare(request)
+        # Forbid special files like sockets, pipes, devices, etc.
+        if response_result is _FileResponseResult.NOT_ACCEPTABLE:
+            self.set_status(HTTPForbidden.status_code)
+            return await super().prepare(request)
+
+        if response_result is _FileResponseResult.PRE_CONDITION_FAILED:
+            return await self._precondition_failed(request)
+
+        if response_result is _FileResponseResult.NOT_MODIFIED:
+            etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
+            last_modified = st.st_mtime
+            return await self._not_modified(request, etag_value, last_modified)
 
+        assert fobj is not None
+        try:
             return await self._prepare_open_file(request, fobj, st, file_encoding)
         finally:
-            if fobj:
-                # We do not await here because we do not want to wait
-                # for the executor to finish before returning the response
-                # so the connection can begin servicing another request
-                # as soon as possible.
-                close_future = loop.run_in_executor(None, fobj.close)
-                # Hold a strong reference to the future to prevent it from being
-                # garbage collected before it completes.
-                _CLOSE_FUTURES.add(close_future)
-                close_future.add_done_callback(_CLOSE_FUTURES.remove)
+            # We do not await here because we do not want to wait
+            # for the executor to finish before returning the response
+            # so the connection can begin servicing another request
+            # as soon as possible.
+            close_future = loop.run_in_executor(None, fobj.close)
+            # Hold a strong reference to the future to prevent it from being
+            # garbage collected before it completes.
+            _CLOSE_FUTURES.add(close_future)
+            close_future.add_done_callback(_CLOSE_FUTURES.remove)
 
     async def _prepare_open_file(
         self,
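The rewritten ``finally`` block closes the file in the executor without awaiting the result, so the connection can begin servicing the next request immediately. The one subtlety is the strong reference: ``run_in_executor`` returns a future, and if nothing holds it, it can be garbage collected before the close actually runs. A self-contained sketch of that pattern (``_PENDING_CLOSES`` and ``close_in_background`` are illustrative stand-ins for the module's ``_CLOSE_FUTURES`` machinery)::

    import asyncio
    from typing import IO, Set

    _PENDING_CLOSES: Set["asyncio.Future[None]"] = set()

    def close_in_background(loop: asyncio.AbstractEventLoop, fobj: IO[bytes]) -> None:
        # Run the blocking close() in the default executor without awaiting,
        # so the response path is not delayed by disk I/O.
        fut = loop.run_in_executor(None, fobj.close)
        # Keep a strong reference until the future completes, then drop it.
        _PENDING_CLOSES.add(fut)
        fut.add_done_callback(_PENDING_CLOSES.remove)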
@@ -250,47 +300,13 @@ async def _prepare_open_file(
         st: os.stat_result,
         file_encoding: Optional[str],
     ) -> Optional[AbstractStreamWriter]:
-        etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
-        last_modified = st.st_mtime
-
-        # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
-        ifmatch = request.if_match
-        if ifmatch is not None and not self._etag_match(
-            etag_value, ifmatch, weak=False
-        ):
-            return await self._precondition_failed(request)
-
-        unmodsince = request.if_unmodified_since
-        if (
-            unmodsince is not None
-            and ifmatch is None
-            and st.st_mtime > unmodsince.timestamp()
-        ):
-            return await self._precondition_failed(request)
-
-        # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
-        ifnonematch = request.if_none_match
-        if ifnonematch is not None and self._etag_match(
-            etag_value, ifnonematch, weak=True
-        ):
-            return await self._not_modified(request, etag_value, last_modified)
-
-        modsince = request.if_modified_since
-        if (
-            modsince is not None
-            and ifnonematch is None
-            and st.st_mtime <= modsince.timestamp()
-        ):
-            return await self._not_modified(request, etag_value, last_modified)
-
         status = self._status
-        file_size = st.st_size
-        count = file_size
-
-        start = None
+        file_size: int = st.st_size
+        file_mtime: float = st.st_mtime
+        count: int = file_size
+        start: Optional[int] = None
 
-        ifrange = request.if_range
-        if ifrange is None or st.st_mtime <= ifrange.timestamp():
+        if (ifrange := request.if_range) is None or file_mtime <= ifrange.timestamp():
             # If-Range header check:
             # condition = cached date >= last modification date
             # return 206 if True else 200.
@@ -301,7 +317,7 @@ async def _prepare_open_file(
             try:
                 rng = request.http_range
                 start = rng.start
-                end = rng.stop
+                end: Optional[int] = rng.stop
             except ValueError:
                 # https://tools.ietf.org/html/rfc7233:
                 # A server generating a 416 (Range Not Satisfiable) response to
@@ -318,7 +334,7 @@ async def _prepare_open_file(
 
         # If a range request has been made, convert start, end slice
         # notation into file pointer offset and count
-        if start is not None or end is not None:
+        if start is not None:
             if start < 0 and end is None:  # return tail of file
                 start += file_size
                 if start < 0:
@@ -375,26 +391,24 @@ async def _prepare_open_file(
             # compress.
             self._compression = False
 
-        self.etag = etag_value  # type: ignore[assignment]
-        self.last_modified = st.st_mtime  # type: ignore[assignment]
+        self.etag = f"{st.st_mtime_ns:x}-{st.st_size:x}"  # type: ignore[assignment]
+        self.last_modified = file_mtime  # type: ignore[assignment]
         self.content_length = count
 
         self._headers[hdrs.ACCEPT_RANGES] = "bytes"
 
-        real_start = cast(int, start)
-
         if status == HTTPPartialContent.status_code:
+            real_start = start
+            assert real_start is not None
             self._headers[hdrs.CONTENT_RANGE] = "bytes {}-{}/{}".format(
                 real_start, real_start + count - 1, file_size
             )
 
         # If we are sending 0 bytes calling sendfile() will throw a ValueError
-        if count == 0 or must_be_empty_body(request.method, self.status):
+        if count == 0 or must_be_empty_body(request.method, status):
            return await super().prepare(request)
 
-        if start:  # be aware that start could be None or int=0 here.
-            offset = start
-        else:
-            offset = 0
+        # be aware that start could be None or int=0 here.
+        offset = start or 0
 
         return await self._sendfile(request, fobj, offset, count)
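One small point worth pausing on in the last hunk: ``start`` can still be ``None`` (no usable ``Range`` header) or ``0`` (a range beginning at the first byte), and ``start or 0`` deliberately maps both to offset ``0``, which is why the replaced if/else and the one-liner are equivalent::

    for start in (None, 0, 512):
        offset = start or 0
        print(start, "->", offset)  # None -> 0, 0 -> 0, 512 -> 512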
diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py
index 69e89bb34c2..65afc05ae10 100644
--- a/tests/test_http_writer.py
+++ b/tests/test_http_writer.py
@@ -123,16 +123,15 @@ async def test_write_large_payload_deflate_compression_data_in_eof(
     assert transport.write.called  # type: ignore[attr-defined]
     chunks = [c[1][0] for c in list(transport.write.mock_calls)]  # type: ignore[attr-defined]
     transport.write.reset_mock()  # type: ignore[attr-defined]
-    assert not transport.writelines.called  # type: ignore[attr-defined]
 
     # This payload compresses to 20447 bytes
     payload = b"".join(
         [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)]
     )
     await msg.write_eof(payload)
-    assert not transport.write.called  # type: ignore[attr-defined]
-    assert transport.writelines.called  # type: ignore[attr-defined]
-    chunks.extend(transport.writelines.mock_calls[0][1][0])  # type: ignore[attr-defined]
+    chunks.extend([c[1][0] for c in list(transport.write.mock_calls)])  # type: ignore[attr-defined]
+
+    assert all(chunks)
     content = b"".join(chunks)
     assert zlib.decompress(content) == (b"data" * 4096) + payload
@@ -202,7 +201,7 @@ async def test_write_payload_deflate_compression_chunked(
     await msg.write(b"data")
     await msg.write_eof()
 
-    chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]  # type: ignore[attr-defined]
+    chunks = [c[1][0] for c in list(transport.write.mock_calls)]  # type: ignore[attr-defined]
     assert all(chunks)
     content = b"".join(chunks)
     assert content == expected
@@ -238,7 +237,7 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof(
     await msg.write(b"data")
     await msg.write_eof(b"end")
 
-    chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]  # type: ignore[attr-defined]
+    chunks = [c[1][0] for c in list(transport.write.mock_calls)]  # type: ignore[attr-defined]
     assert all(chunks)
     content = b"".join(chunks)
     assert content == expected
@@ -257,16 +256,16 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
     # This payload compresses to 1111 bytes
     payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)])
     await msg.write_eof(payload)
-    assert not transport.write.called  # type: ignore[attr-defined]
-    chunks = []
-    for write_lines_call in transport.writelines.mock_calls:  # type: ignore[attr-defined]
-        chunked_payload = list(write_lines_call[1][0])[1:]
-        chunked_payload.pop()
-        chunks.extend(chunked_payload)
+    compressed = []
+    chunks = [c[1][0] for c in list(transport.write.mock_calls)]  # type: ignore[attr-defined]
+    chunked_body = b"".join(chunks)
+    split_body = chunked_body.split(b"\r\n")
+    while split_body:
+        if split_body.pop(0):
+            compressed.append(split_body.pop(0))
 
-    assert all(chunks)
-    content = b"".join(chunks)
+    content = b"".join(compressed)
     assert zlib.decompress(content) == (b"data" * 4096) + payload
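The last test can no longer pull pre-split chunks out of ``writelines`` mock calls, so it recovers the compressed payload by parsing the chunked body itself: a chunked body is ``<size-hex>\r\n<data>\r\n`` repeated, terminated by ``0\r\n\r\n``, so splitting on CRLF alternates size lines and data lines. A standalone sketch of the same trick (it relies, as the test does, on no data chunk containing the CRLF sequence itself)::

    chunked_body = b"4\r\nWiki\r\n5\r\npedia\r\n0\r\n\r\n"
    split_body = chunked_body.split(b"\r\n")

    payload = []
    while split_body:
        if split_body.pop(0):                  # non-empty -> a chunk-size line
            payload.append(split_body.pop(0))  # the chunk's data follows it

    # The terminal 0-size chunk contributes an empty string,
    # which the join ignores.
    print(b"".join(payload))  # b"Wikipedia"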