From bc11725f1ca49660cc3c5e192f5dd22d286ca225 Mon Sep 17 00:00:00 2001 From: Ankit Saini Date: Thu, 11 Jun 2026 16:18:31 +0530 Subject: [PATCH 01/19] increase uvicorn keep alive --- singlestoredb/apps/_python_udfs.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/singlestoredb/apps/_python_udfs.py b/singlestoredb/apps/_python_udfs.py index e45718dec..5c8cf4a73 100644 --- a/singlestoredb/apps/_python_udfs.py +++ b/singlestoredb/apps/_python_udfs.py @@ -61,11 +61,18 @@ async def run_udf_app( f'You can only define a maximum of {MAX_UDFS_LIMIT} functions.', ) + # Increase the timeout so the uvicorn server is not the one closing idle connections. + # Avoiding TIME_WAIT state, rendering the client_port unusable for 60s (default TIME_WAIT duration). + keep_alive_timeout = int( + os.environ.get('SINGLESTOREDB_UDF_KEEPALIVE_TIMEOUT', '120'), + ) + config = uvicorn.Config( app, host='0.0.0.0', port=app_config.listen_port, log_config=app.get_uvicorn_log_config(), + timeout_keep_alive=keep_alive_timeout, ) # Register the functions only if the app is running interactively. From c6e96b9cc3c59d8b84139b4f884a833c6192b9c8 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 14 May 2026 10:07:34 -0500 Subject: [PATCH 02/19] Fix async UDF event loop starvation under heavy load in Jupyter Async UDFs were running directly in uvicorn's event loop via asyncio.create_task, competing with connection handling under heavy concurrent load. This caused unresponsiveness when running from Jupyter notebooks where the event loop is shared. The fix introduces a dedicated event loop in a background thread for async UDF execution. Coroutines are submitted via run_coroutine_threadsafe() and awaited from the server loop, isolating UDF work from HTTP I/O while preserving cooperative async scheduling between UDFs. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/asgi.py | 38 +++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 331828d17..3929e0466 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -1006,6 +1006,15 @@ def __init__( self.log_level = log_level self.disable_metrics = disable_metrics + # Dedicated event loop for async UDF execution, isolated from the server loop + self._udf_loop = asyncio.new_event_loop() + self._udf_thread = threading.Thread( + target=self._udf_loop.run_forever, + daemon=True, + name='async-udf-loop', + ) + self._udf_thread.start() + # Configure logging self._configure_logging() @@ -1039,6 +1048,11 @@ def _configure_logging(self) -> None: # Prevent propagation to avoid duplicate or differently formatted messages self.logger.propagate = False + def shutdown(self) -> None: + """Shut down the dedicated UDF event loop.""" + self._udf_loop.call_soon_threadsafe(self._udf_loop.stop) + self._udf_thread.join(timeout=5) + def get_uvicorn_log_config(self) -> Dict[str, Any]: """ Create uvicorn log config that matches the Application's logging format. @@ -1195,15 +1209,23 @@ async def __call__( func_info['colspec'], b''.join(data), ) - func_task = asyncio.create_task( - func(cancel_event, call_timer, *inputs) - if func_info['is_async'] - else to_thread( - lambda: asyncio.run( - func(cancel_event, call_timer, *inputs), + func_task: 'asyncio.Task[Any]' + if func_info['is_async']: + future = asyncio.run_coroutine_threadsafe( + func(cancel_event, call_timer, *inputs), + self._udf_loop, + ) + func_task = asyncio.create_task( + asyncio.wrap_future(future), # type: ignore[arg-type] + ) + else: + func_task = asyncio.create_task( + to_thread( + lambda: asyncio.run( + func(cancel_event, call_timer, *inputs), + ), ), - ), - ) + ) disconnect_task = asyncio.create_task( asyncio.sleep(int(1e9)) if ignore_cancel else cancel_on_disconnect(receive), From 16b9426e95462fd4095e0fdc680fde661d6d7b24 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 14 May 2026 10:11:01 -0500 Subject: [PATCH 03/19] Ensure proper cancellation of async UDFs in dedicated loop Cancel the concurrent.futures.Future in the UDF loop on disconnect/timeout so the coroutine is interrupted promptly, not just at the next cancel_on_event row check. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/asgi.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 3929e0466..fa44348dd 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -24,6 +24,7 @@ """ import argparse import asyncio +import concurrent.futures import contextvars import dataclasses import datetime @@ -1210,13 +1211,14 @@ async def __call__( ) func_task: 'asyncio.Task[Any]' + udf_future: 'Optional[concurrent.futures.Future[Any]]' = None if func_info['is_async']: - future = asyncio.run_coroutine_threadsafe( + udf_future = asyncio.run_coroutine_threadsafe( func(cancel_event, call_timer, *inputs), self._udf_loop, ) func_task = asyncio.create_task( - asyncio.wrap_future(future), # type: ignore[arg-type] + asyncio.wrap_future(udf_future), # type: ignore[arg-type] ) else: func_task = asyncio.create_task( @@ -1246,12 +1248,16 @@ async def __call__( for task in done: if task is disconnect_task: cancel_event.set() + if udf_future is not None: + udf_future.cancel() raise asyncio.CancelledError( 'Function call was cancelled by client disconnect', ) elif task is timeout_task: cancel_event.set() + if udf_future is not None: + udf_future.cancel() raise asyncio.TimeoutError( 'Function call was cancelled due to timeout', ) From 8bed545f6d40a0a40cab7c513f929b68e99107e0 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 14 May 2026 10:19:46 -0500 Subject: [PATCH 04/19] Fix create_task expecting coroutine, use ensure_future for wrapped future asyncio.create_task() requires a coroutine but asyncio.wrap_future() returns a Future. Use asyncio.ensure_future() which accepts both. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/asgi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index fa44348dd..ce0c39975 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -1217,8 +1217,8 @@ async def __call__( func(cancel_event, call_timer, *inputs), self._udf_loop, ) - func_task = asyncio.create_task( - asyncio.wrap_future(udf_future), # type: ignore[arg-type] + func_task = asyncio.ensure_future( + asyncio.wrap_future(udf_future), ) else: func_task = asyncio.create_task( From 302722792ea6c57313efe2b6ebdafeded0ce0427 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 14 May 2026 10:43:14 -0500 Subject: [PATCH 05/19] Call Application.shutdown() when replacing UDF server Prevents UDF event loop thread leaks when run_udf_app() is called repeatedly in Jupyter notebooks. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/apps/_python_udfs.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/singlestoredb/apps/_python_udfs.py b/singlestoredb/apps/_python_udfs.py index 5c8cf4a73..84581945b 100644 --- a/singlestoredb/apps/_python_udfs.py +++ b/singlestoredb/apps/_python_udfs.py @@ -10,8 +10,9 @@ if typing.TYPE_CHECKING: from ._uvicorn_util import AwaitableUvicornServer -# Keep track of currently running server +# Keep track of currently running server and app _running_server: 'typing.Optional[AwaitableUvicornServer]' = None +_running_app: typing.Optional[Application] = None # Maximum number of UDFs allowed MAX_UDFS_LIMIT = 10 @@ -21,7 +22,7 @@ async def run_udf_app( log_level: str = 'error', kill_existing_app_server: bool = True, ) -> UdfConnectionInfo: - global _running_server + global _running_server, _running_app from ._uvicorn_util import AwaitableUvicornServer try: @@ -38,6 +39,9 @@ async def run_udf_app( if _running_server is not None: await _running_server.shutdown() _running_server = None + if _running_app is not None: + _running_app.shutdown() + _running_app = None # Kill if any other process is occupying the port kill_process_by_port(app_config.listen_port) @@ -79,6 +83,7 @@ async def run_udf_app( if app_config.running_interactively: app.register_functions(replace=True) + _running_app = app _running_server = AwaitableUvicornServer(config) asyncio.create_task(_running_server.serve()) await _running_server.wait_for_startup() From 6f2f8abd572f8e51b96a65c3e4e8a38c80397582 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 14 May 2026 11:23:28 -0500 Subject: [PATCH 06/19] Propagate cancellation to UDF loop and prevent thread leak on failure - Cancel udf_future when func_task is in pending set after asyncio.wait - Cancel udf_future in finally block to ensure cleanup on any exit path - Wrap post-construction code in try/except to call app.shutdown() if validation, config, or registration fails after Application is created Co-Authored-By: Claude Opus 4.6 --- singlestoredb/apps/_python_udfs.py | 61 +++++++++++++++++------------ singlestoredb/functions/ext/asgi.py | 6 +++ 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/singlestoredb/apps/_python_udfs.py b/singlestoredb/apps/_python_udfs.py index 84581945b..b3f8002f8 100644 --- a/singlestoredb/apps/_python_udfs.py +++ b/singlestoredb/apps/_python_udfs.py @@ -58,35 +58,46 @@ async def run_udf_app( log_level=log_level, ) - if not app.endpoints: - raise ValueError('You must define at least one function.') - if len(app.endpoints) > MAX_UDFS_LIMIT: - raise ValueError( - f'You can only define a maximum of {MAX_UDFS_LIMIT} functions.', - ) + try: + if not app.endpoints: + raise ValueError('You must define at least one function.') + if len(app.endpoints) > MAX_UDFS_LIMIT: + raise ValueError( + f'You can only define a maximum of {MAX_UDFS_LIMIT} functions.', + ) - # Increase the timeout so the uvicorn server is not the one closing idle connections. - # Avoiding TIME_WAIT state, rendering the client_port unusable for 60s (default TIME_WAIT duration). - keep_alive_timeout = int( - os.environ.get('SINGLESTOREDB_UDF_KEEPALIVE_TIMEOUT', '120'), - ) + config = uvicorn.Config( + app, + host='0.0.0.0', + port=app_config.listen_port, + log_config=app.get_uvicorn_log_config(), + ) - config = uvicorn.Config( - app, - host='0.0.0.0', - port=app_config.listen_port, - log_config=app.get_uvicorn_log_config(), - timeout_keep_alive=keep_alive_timeout, - ) + # Increase the timeout so the uvicorn server is not the one closing idle connections. + # Avoiding TIME_WAIT state, rendering the client_port unusable for 60s (default TIME_WAIT duration). + keep_alive_timeout = int( + os.environ.get('SINGLESTOREDB_UDF_KEEPALIVE_TIMEOUT', '120'), + ) - # Register the functions only if the app is running interactively. - if app_config.running_interactively: - app.register_functions(replace=True) + config = uvicorn.Config( + app, + host='0.0.0.0', + port=app_config.listen_port, + log_config=app.get_uvicorn_log_config(), + timeout_keep_alive=keep_alive_timeout, + ) - _running_app = app - _running_server = AwaitableUvicornServer(config) - asyncio.create_task(_running_server.serve()) - await _running_server.wait_for_startup() + # Register the functions only if the app is running interactively. + if app_config.running_interactively: + app.register_functions(replace=True) + + _running_app = app + _running_server = AwaitableUvicornServer(config) + asyncio.create_task(_running_server.serve()) + await _running_server.wait_for_startup() + except Exception: + app.shutdown() + raise print(f'Python UDF registered at {base_url}') diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index ce0c39975..ce9ab8370 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -1244,6 +1244,9 @@ async def __call__( ) await cancel_all_tasks(pending) + if func_task in pending and udf_future is not None: + cancel_event.set() + udf_future.cancel() for task in done: if task is disconnect_task: @@ -1320,6 +1323,9 @@ async def __call__( await send(self.error_response_dict) finally: + if udf_future is not None: + cancel_event.set() + udf_future.cancel() await cancel_all_tasks(all_tasks) # Handle api reflection From 6db2fc17f025667dfc888d55ee9095ce781bd0bc Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 14 May 2026 12:03:30 -0500 Subject: [PATCH 07/19] Fix udf_future NameError and lazily initialize UDF event loop - Move udf_future initialization before input_handler['load']() to prevent NameError in finally block if parsing raises - Lazily create UDF event loop on first async UDF invocation instead of unconditionally in __init__, avoiding wasted resources for sync-only or metadata-only usage - Guard shutdown() against None loop/thread Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/asgi.py | 32 ++++++++++++++++++----------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index ce9ab8370..2519d9486 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -1007,14 +1007,8 @@ def __init__( self.log_level = log_level self.disable_metrics = disable_metrics - # Dedicated event loop for async UDF execution, isolated from the server loop - self._udf_loop = asyncio.new_event_loop() - self._udf_thread = threading.Thread( - target=self._udf_loop.run_forever, - daemon=True, - name='async-udf-loop', - ) - self._udf_thread.start() + self._udf_loop: Optional[asyncio.AbstractEventLoop] = None + self._udf_thread: Optional[threading.Thread] = None # Configure logging self._configure_logging() @@ -1049,10 +1043,24 @@ def _configure_logging(self) -> None: # Prevent propagation to avoid duplicate or differently formatted messages self.logger.propagate = False + def _get_udf_loop(self) -> asyncio.AbstractEventLoop: + """Get or create the dedicated UDF event loop.""" + if self._udf_loop is None: + self._udf_loop = asyncio.new_event_loop() + self._udf_thread = threading.Thread( + target=self._udf_loop.run_forever, + daemon=True, + name='async-udf-loop', + ) + self._udf_thread.start() + return self._udf_loop + def shutdown(self) -> None: """Shut down the dedicated UDF event loop.""" - self._udf_loop.call_soon_threadsafe(self._udf_loop.stop) - self._udf_thread.join(timeout=5) + if self._udf_loop is not None: + self._udf_loop.call_soon_threadsafe(self._udf_loop.stop) + if self._udf_thread is not None: + self._udf_thread.join(timeout=5) def get_uvicorn_log_config(self) -> Dict[str, Any]: """ @@ -1202,6 +1210,7 @@ async def __call__( try: all_tasks = [] result = [] + udf_future: 'Optional[concurrent.futures.Future[Any]]' = None cancel_event = threading.Event() @@ -1211,11 +1220,10 @@ async def __call__( ) func_task: 'asyncio.Task[Any]' - udf_future: 'Optional[concurrent.futures.Future[Any]]' = None if func_info['is_async']: udf_future = asyncio.run_coroutine_threadsafe( func(cancel_event, call_timer, *inputs), - self._udf_loop, + self._get_udf_loop(), ) func_task = asyncio.ensure_future( asyncio.wrap_future(udf_future), From 2b65cd0e44a4df433d563edcf4b82578c19c6549 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 14 May 2026 15:18:49 -0500 Subject: [PATCH 08/19] Reset UDF loop state in shutdown() to allow safe reuse After stopping the event loop and joining the thread, set both _udf_loop and _udf_thread back to None so that _get_udf_loop() can safely recreate them if called after shutdown. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/asgi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 2519d9486..ba3b17bdc 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -1061,6 +1061,8 @@ def shutdown(self) -> None: self._udf_loop.call_soon_threadsafe(self._udf_loop.stop) if self._udf_thread is not None: self._udf_thread.join(timeout=5) + self._udf_loop = None + self._udf_thread = None def get_uvicorn_log_config(self) -> Dict[str, Any]: """ From 1cffe175e8989d3d974db075cd9c7d8cb8a1c065 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Mon, 18 May 2026 09:50:57 -0500 Subject: [PATCH 09/19] Use thread-per-request for async UDFs instead of shared event loop The dedicated shared event loop still caused starvation under concurrent async UDF calls. Switch to the same model used by sync UDFs: each request gets its own thread with asyncio.run(), eliminating loop contention. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/apps/_python_udfs.py | 68 +++++++++++------------------ singlestoredb/functions/ext/asgi.py | 59 ++++--------------------- 2 files changed, 34 insertions(+), 93 deletions(-) diff --git a/singlestoredb/apps/_python_udfs.py b/singlestoredb/apps/_python_udfs.py index b3f8002f8..5c8cf4a73 100644 --- a/singlestoredb/apps/_python_udfs.py +++ b/singlestoredb/apps/_python_udfs.py @@ -10,9 +10,8 @@ if typing.TYPE_CHECKING: from ._uvicorn_util import AwaitableUvicornServer -# Keep track of currently running server and app +# Keep track of currently running server _running_server: 'typing.Optional[AwaitableUvicornServer]' = None -_running_app: typing.Optional[Application] = None # Maximum number of UDFs allowed MAX_UDFS_LIMIT = 10 @@ -22,7 +21,7 @@ async def run_udf_app( log_level: str = 'error', kill_existing_app_server: bool = True, ) -> UdfConnectionInfo: - global _running_server, _running_app + global _running_server from ._uvicorn_util import AwaitableUvicornServer try: @@ -39,9 +38,6 @@ async def run_udf_app( if _running_server is not None: await _running_server.shutdown() _running_server = None - if _running_app is not None: - _running_app.shutdown() - _running_app = None # Kill if any other process is occupying the port kill_process_by_port(app_config.listen_port) @@ -58,46 +54,34 @@ async def run_udf_app( log_level=log_level, ) - try: - if not app.endpoints: - raise ValueError('You must define at least one function.') - if len(app.endpoints) > MAX_UDFS_LIMIT: - raise ValueError( - f'You can only define a maximum of {MAX_UDFS_LIMIT} functions.', - ) - - config = uvicorn.Config( - app, - host='0.0.0.0', - port=app_config.listen_port, - log_config=app.get_uvicorn_log_config(), + if not app.endpoints: + raise ValueError('You must define at least one function.') + if len(app.endpoints) > MAX_UDFS_LIMIT: + raise ValueError( + f'You can only define a maximum of {MAX_UDFS_LIMIT} functions.', ) - # Increase the timeout so the uvicorn server is not the one closing idle connections. - # Avoiding TIME_WAIT state, rendering the client_port unusable for 60s (default TIME_WAIT duration). - keep_alive_timeout = int( - os.environ.get('SINGLESTOREDB_UDF_KEEPALIVE_TIMEOUT', '120'), - ) + # Increase the timeout so the uvicorn server is not the one closing idle connections. + # Avoiding TIME_WAIT state, rendering the client_port unusable for 60s (default TIME_WAIT duration). + keep_alive_timeout = int( + os.environ.get('SINGLESTOREDB_UDF_KEEPALIVE_TIMEOUT', '120'), + ) - config = uvicorn.Config( - app, - host='0.0.0.0', - port=app_config.listen_port, - log_config=app.get_uvicorn_log_config(), - timeout_keep_alive=keep_alive_timeout, - ) + config = uvicorn.Config( + app, + host='0.0.0.0', + port=app_config.listen_port, + log_config=app.get_uvicorn_log_config(), + timeout_keep_alive=keep_alive_timeout, + ) + + # Register the functions only if the app is running interactively. + if app_config.running_interactively: + app.register_functions(replace=True) - # Register the functions only if the app is running interactively. - if app_config.running_interactively: - app.register_functions(replace=True) - - _running_app = app - _running_server = AwaitableUvicornServer(config) - asyncio.create_task(_running_server.serve()) - await _running_server.wait_for_startup() - except Exception: - app.shutdown() - raise + _running_server = AwaitableUvicornServer(config) + asyncio.create_task(_running_server.serve()) + await _running_server.wait_for_startup() print(f'Python UDF registered at {base_url}') diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index ba3b17bdc..b22f1c262 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -24,7 +24,6 @@ """ import argparse import asyncio -import concurrent.futures import contextvars import dataclasses import datetime @@ -1007,9 +1006,6 @@ def __init__( self.log_level = log_level self.disable_metrics = disable_metrics - self._udf_loop: Optional[asyncio.AbstractEventLoop] = None - self._udf_thread: Optional[threading.Thread] = None - # Configure logging self._configure_logging() @@ -1043,27 +1039,6 @@ def _configure_logging(self) -> None: # Prevent propagation to avoid duplicate or differently formatted messages self.logger.propagate = False - def _get_udf_loop(self) -> asyncio.AbstractEventLoop: - """Get or create the dedicated UDF event loop.""" - if self._udf_loop is None: - self._udf_loop = asyncio.new_event_loop() - self._udf_thread = threading.Thread( - target=self._udf_loop.run_forever, - daemon=True, - name='async-udf-loop', - ) - self._udf_thread.start() - return self._udf_loop - - def shutdown(self) -> None: - """Shut down the dedicated UDF event loop.""" - if self._udf_loop is not None: - self._udf_loop.call_soon_threadsafe(self._udf_loop.stop) - if self._udf_thread is not None: - self._udf_thread.join(timeout=5) - self._udf_loop = None - self._udf_thread = None - def get_uvicorn_log_config(self) -> Dict[str, Any]: """ Create uvicorn log config that matches the Application's logging format. @@ -1212,7 +1187,6 @@ async def __call__( try: all_tasks = [] result = [] - udf_future: 'Optional[concurrent.futures.Future[Any]]' = None cancel_event = threading.Event() @@ -1221,23 +1195,13 @@ async def __call__( func_info['colspec'], b''.join(data), ) - func_task: 'asyncio.Task[Any]' - if func_info['is_async']: - udf_future = asyncio.run_coroutine_threadsafe( - func(cancel_event, call_timer, *inputs), - self._get_udf_loop(), - ) - func_task = asyncio.ensure_future( - asyncio.wrap_future(udf_future), - ) - else: - func_task = asyncio.create_task( - to_thread( - lambda: asyncio.run( - func(cancel_event, call_timer, *inputs), - ), + func_task = asyncio.create_task( + to_thread( + lambda: asyncio.run( + func(cancel_event, call_timer, *inputs), ), - ) + ), + ) disconnect_task = asyncio.create_task( asyncio.sleep(int(1e9)) if ignore_cancel else cancel_on_disconnect(receive), @@ -1254,23 +1218,18 @@ async def __call__( ) await cancel_all_tasks(pending) - if func_task in pending and udf_future is not None: + if func_task in pending: cancel_event.set() - udf_future.cancel() for task in done: if task is disconnect_task: cancel_event.set() - if udf_future is not None: - udf_future.cancel() raise asyncio.CancelledError( 'Function call was cancelled by client disconnect', ) elif task is timeout_task: cancel_event.set() - if udf_future is not None: - udf_future.cancel() raise asyncio.TimeoutError( 'Function call was cancelled due to timeout', ) @@ -1333,9 +1292,7 @@ async def __call__( await send(self.error_response_dict) finally: - if udf_future is not None: - cancel_event.set() - udf_future.cancel() + cancel_event.set() await cancel_all_tasks(all_tasks) # Handle api reflection From 19d00da4b00deb6b7985b0b7a6132b4f06312118 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Mon, 18 May 2026 10:12:09 -0500 Subject: [PATCH 10/19] Add cancellable wrapper for responsive async UDF cancellation Wraps the inner coroutine in _cancellable_run which polls cancel_event and raises CancelledError at the next await (~100ms), ensuring vector UDFs respect disconnect/timeout signals without waiting for completion. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/asgi.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index b22f1c262..e426a8005 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -113,6 +113,28 @@ async def to_thread( return await loop.run_in_executor(None, func_call) +async def _poll_cancel(cancel_event: threading.Event) -> None: + while not cancel_event.is_set(): + await asyncio.sleep(0.1) + + +async def _cancellable_run( + cancel_event: threading.Event, + coro: Any, +) -> Any: + task = asyncio.create_task(coro) + cancel_check = asyncio.create_task(_poll_cancel(cancel_event)) + done, pending = await asyncio.wait( + [task, cancel_check], return_when=asyncio.FIRST_COMPLETED, + ) + for p in pending: + p.cancel() + if cancel_check in done: + task.cancel() + raise asyncio.CancelledError() + return task.result() + + # Use negative values to indicate unsigned ints / binary data / usec time precision rowdat_1_type_map = { 'bool': ft.LONGLONG, @@ -1198,7 +1220,10 @@ async def __call__( func_task = asyncio.create_task( to_thread( lambda: asyncio.run( - func(cancel_event, call_timer, *inputs), + _cancellable_run( + cancel_event, + func(cancel_event, call_timer, *inputs), + ), ), ), ) From e260b198ca179090e6d27b957b3e790d8b82eb7c Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 19 May 2026 09:00:21 -0500 Subject: [PATCH 11/19] Fix event loop closed error and add comprehensive UDF dispatch tests Replace asyncio.run() with _run_with_graceful_shutdown() that drains pending callbacks before closing the loop, preventing RuntimeError from httpx/anyio TLS cleanup in async UDFs calling OpenAI/LangChain APIs. Add 17 unit tests covering graceful shutdown, cancellation timing, exception propagation, context variable isolation, and concurrent safety. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/asgi.py | 30 ++- singlestoredb/tests/test_udf_event_loop.py | 296 +++++++++++++++++++++ 2 files changed, 325 insertions(+), 1 deletion(-) create mode 100644 singlestoredb/tests/test_udf_event_loop.py diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index e426a8005..329cae64f 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -135,6 +135,34 @@ async def _cancellable_run( return task.result() +def _run_with_graceful_shutdown(coro: Any) -> Any: + """Run a coroutine in a new event loop, draining callbacks before close. + + Unlike asyncio.run(), this prevents 'Event loop is closed' errors from + libraries (httpx/anyio) that schedule cleanup callbacks during teardown. + """ + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + finally: + try: + pending = asyncio.all_tasks(loop) + if pending: + for task in pending: + task.cancel() + loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True), + ) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete(loop.shutdown_default_executor()) + finally: + loop.call_soon(loop.stop) + loop.run_forever() + asyncio.set_event_loop(None) + loop.close() + + # Use negative values to indicate unsigned ints / binary data / usec time precision rowdat_1_type_map = { 'bool': ft.LONGLONG, @@ -1219,7 +1247,7 @@ async def __call__( func_task = asyncio.create_task( to_thread( - lambda: asyncio.run( + lambda: _run_with_graceful_shutdown( _cancellable_run( cancel_event, func(cancel_event, call_timer, *inputs), diff --git a/singlestoredb/tests/test_udf_event_loop.py b/singlestoredb/tests/test_udf_event_loop.py new file mode 100644 index 000000000..279f4a184 --- /dev/null +++ b/singlestoredb/tests/test_udf_event_loop.py @@ -0,0 +1,296 @@ +"""Tests for async UDF event loop graceful shutdown.""" +import asyncio +import contextvars +import threading +import time +import unittest +from typing import Any +from typing import List + +from ..functions.ext.asgi import _cancellable_run +from ..functions.ext.asgi import _run_with_graceful_shutdown +from ..functions.ext.asgi import to_thread + + +class TestRunWithGracefulShutdown(unittest.TestCase): + """Test _run_with_graceful_shutdown handles loop cleanup properly.""" + + def test_basic_coroutine(self) -> None: + async def simple() -> int: + return 42 + + result = _run_with_graceful_shutdown(simple()) + self.assertEqual(result, 42) + + def test_callbacks_drained_before_close(self) -> None: + """Simulate httpx/anyio scheduling call_soon during teardown. + + This is the exact pattern that causes 'Event loop is closed' with + asyncio.run() -- a library schedules a callback in its __del__ or + aclose() that fires after the loop is closed. + """ + callback_executed: List[bool] = [] + + async def coroutine_with_cleanup_callback() -> str: + loop = asyncio.get_running_loop() + loop.call_soon(lambda: callback_executed.append(True)) + return 'done' + + result = _run_with_graceful_shutdown(coroutine_with_cleanup_callback()) + self.assertEqual(result, 'done') + self.assertEqual(callback_executed, [True]) + + def test_no_event_loop_closed_error(self) -> None: + """Verify no RuntimeError when cleanup schedules on the loop.""" + errors: List[RuntimeError] = [] + + async def simulate_httpx_teardown() -> str: + loop = asyncio.get_running_loop() + + def deferred_cleanup() -> None: + try: + loop.call_soon(lambda: None) + except RuntimeError as e: + errors.append(e) + + loop.call_soon(deferred_cleanup) + return 'ok' + + result = _run_with_graceful_shutdown(simulate_httpx_teardown()) + self.assertEqual(result, 'ok') + self.assertEqual(errors, []) + + def test_exception_propagates(self) -> None: + async def failing() -> None: + raise ValueError('test error') + + with self.assertRaises(ValueError) as ctx: + _run_with_graceful_shutdown(failing()) + self.assertEqual(str(ctx.exception), 'test error') + + def test_callbacks_drained_even_on_exception(self) -> None: + """Cleanup callbacks still run even if coroutine raises.""" + callback_executed: List[bool] = [] + + async def failing_with_callback() -> None: + loop = asyncio.get_running_loop() + loop.call_soon(lambda: callback_executed.append(True)) + raise ValueError('boom') + + with self.assertRaises(ValueError): + _run_with_graceful_shutdown(failing_with_callback()) + self.assertEqual(callback_executed, [True]) + + def test_pending_tasks_cancelled(self) -> None: + """Background tasks are cancelled during shutdown.""" + async def background() -> None: + await asyncio.sleep(999) + + async def main_with_background_task() -> str: + asyncio.create_task(background()) + return 'done' + + result = _run_with_graceful_shutdown(main_with_background_task()) + self.assertEqual(result, 'done') + + def test_isolation_between_calls(self) -> None: + """Each call gets its own event loop that is closed after use.""" + loops: List[asyncio.AbstractEventLoop] = [] + + async def capture_loop() -> bool: + loops.append(asyncio.get_running_loop()) + return True + + _run_with_graceful_shutdown(capture_loop()) + first_loop = loops[0] + self.assertTrue(first_loop.is_closed()) + + _run_with_graceful_shutdown(capture_loop()) + second_loop = loops[1] + self.assertTrue(second_loop.is_closed()) + + def test_cancellable_run_integration(self) -> None: + """Verify _cancellable_run works inside _run_with_graceful_shutdown.""" + cancel_event = threading.Event() + + async def slow_func() -> str: + return 'completed' + + result = _run_with_graceful_shutdown( + _cancellable_run(cancel_event, slow_func()), + ) + self.assertEqual(result, 'completed') + + def test_cancellation_via_event(self) -> None: + """Verify cancellation propagates through the full stack.""" + cancel_event = threading.Event() + cancel_event.set() + + async def blocked_func() -> str: + await asyncio.sleep(999) + return 'should not reach' + + with self.assertRaises(asyncio.CancelledError): + _run_with_graceful_shutdown( + _cancellable_run(cancel_event, blocked_func()), + ) + + +class TestUDFDispatchEdgeCases(unittest.TestCase): + """Test edge cases in the UDF dispatch stack.""" + + def test_timeout_cancels_running_function(self) -> None: + """Cancel event set from timer thread cancels a blocked coroutine.""" + cancel_event = threading.Event() + + async def long_running() -> str: + await asyncio.sleep(999) + return 'should not reach' + + def set_cancel_after_delay() -> None: + time.sleep(0.2) + cancel_event.set() + + timer = threading.Thread(target=set_cancel_after_delay) + timer.start() + + start = time.monotonic() + with self.assertRaises(asyncio.CancelledError): + _run_with_graceful_shutdown( + _cancellable_run(cancel_event, long_running()), + ) + elapsed = time.monotonic() - start + timer.join() + # 0.2s delay + up to 0.1s poll interval + margin + self.assertLess(elapsed, 0.5) + + def test_exception_propagates_through_full_stack(self) -> None: + """User exception propagates unwrapped through the entire dispatch.""" + cancel_event = threading.Event() + + class CustomUDFError(Exception): + pass + + async def failing_udf() -> None: + raise CustomUDFError('embedding service unavailable') + + with self.assertRaises(CustomUDFError) as ctx: + _run_with_graceful_shutdown( + _cancellable_run(cancel_event, failing_udf()), + ) + self.assertEqual(str(ctx.exception), 'embedding service unavailable') + + def test_cancel_event_detected_within_poll_interval(self) -> None: + """Cancellation is detected within one poll cycle (0.1s).""" + cancel_event = threading.Event() + + async def blocked() -> str: + await asyncio.sleep(999) + return 'unreachable' + + def set_cancel() -> None: + time.sleep(0.05) + cancel_event.set() + + timer = threading.Thread(target=set_cancel) + timer.start() + + start = time.monotonic() + with self.assertRaises(asyncio.CancelledError): + _run_with_graceful_shutdown( + _cancellable_run(cancel_event, blocked()), + ) + elapsed = time.monotonic() - start + timer.join() + # 0.05s delay + 0.1s poll interval + margin + self.assertLess(elapsed, 0.25) + + def test_context_vars_propagate_through_to_thread(self) -> None: + """Context variables are visible inside to_thread executor.""" + test_var: contextvars.ContextVar[str] = contextvars.ContextVar( + 'test_var', + ) + test_var.set('hello_from_parent') + captured: List[str] = [] + + def read_context_var() -> str: + val = test_var.get('NOT_FOUND') + captured.append(val) + return val + + async def run_in_thread() -> str: + return await to_thread(read_context_var) + + result = _run_with_graceful_shutdown(run_in_thread()) + self.assertEqual(result, 'hello_from_parent') + self.assertEqual(captured, ['hello_from_parent']) + + def test_concurrent_requests_isolated(self) -> None: + """Parallel executions don't share state.""" + results: List[Any] = [None, None, None] + + def run_isolated(index: int) -> None: + async def compute() -> int: + await asyncio.sleep(0.05) + return index * 10 + + results[index] = _run_with_graceful_shutdown(compute()) + + threads = [ + threading.Thread(target=run_isolated, args=(i,)) + for i in range(3) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(results, [0, 10, 20]) + + def test_sync_function_through_async_wrapper(self) -> None: + """Synchronous function works when wrapped as async coroutine.""" + cancel_event = threading.Event() + + async def sync_as_async() -> int: + # Simulates what decorator.py's async_wrapper does for sync UDFs + return 42 + 1 + + result = _run_with_graceful_shutdown( + _cancellable_run(cancel_event, sync_as_async()), + ) + self.assertEqual(result, 43) + + def test_cancel_event_not_set_on_success(self) -> None: + """Cancel event remains unset after successful execution.""" + cancel_event = threading.Event() + + async def quick() -> str: + return 'fast' + + result = _run_with_graceful_shutdown( + _cancellable_run(cancel_event, quick()), + ) + self.assertEqual(result, 'fast') + self.assertFalse(cancel_event.is_set()) + + def test_callbacks_from_cancelled_tasks_still_drain(self) -> None: + """Background task callbacks drain even when task is cancelled.""" + drained: List[bool] = [] + + async def bg_with_callback() -> None: + loop = asyncio.get_running_loop() + loop.call_soon(lambda: drained.append(True)) + await asyncio.sleep(999) + + async def main() -> str: + asyncio.create_task(bg_with_callback()) + await asyncio.sleep(0.05) # Let background task start + return 'done' + + result = _run_with_graceful_shutdown(main()) + self.assertEqual(result, 'done') + self.assertEqual(drained, [True]) + + +if __name__ == '__main__': + unittest.main() From c96113b792f99435ab9d92e625cefedea1bc0212 Mon Sep 17 00:00:00 2001 From: Ankit Saini Date: Wed, 3 Jun 2026 23:50:15 +0530 Subject: [PATCH 12/19] reuse event loop across requests on same thread Co-authored-by: Cursor --- singlestoredb/functions/ext/asgi.py | 86 +++++-- singlestoredb/tests/test_udf_event_loop.py | 286 ++++++++++----------- 2 files changed, 203 insertions(+), 169 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 329cae64f..294b17c21 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -24,6 +24,7 @@ """ import argparse import asyncio +import atexit import contextvars import dataclasses import datetime @@ -135,32 +136,67 @@ async def _cancellable_run( return task.result() -def _run_with_graceful_shutdown(coro: Any) -> Any: - """Run a coroutine in a new event loop, draining callbacks before close. +# Each `to_thread` worker thread owns a long-lived event loop reused across +# requests, so loop-bound resources (HTTP pools, DB sessions, sockets) can +# survive between calls handled by the same thread. +_thread_local = threading.local() +_loop_registry: 'Set[asyncio.AbstractEventLoop]' = set() +_loop_registry_lock = threading.Lock() - Unlike asyncio.run(), this prevents 'Event loop is closed' errors from - libraries (httpx/anyio) that schedule cleanup callbacks during teardown. - """ - loop = asyncio.new_event_loop() - try: + +def _get_thread_loop() -> asyncio.AbstractEventLoop: + """Return (creating if needed) the calling thread's persistent loop.""" + loop = getattr(_thread_local, 'loop', None) + if loop is None or loop.is_closed(): + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - return loop.run_until_complete(coro) - finally: + _thread_local.loop = loop + with _loop_registry_lock: + _loop_registry.add(loop) + return loop + + +def _run_on_thread_loop(coro: Any) -> Any: + """ + Run ``coro`` on the calling thread's persistent loop. + + The loop is never closed between calls, so loop-bound resources (e.g. + httpx keep-alive pools) survive across requests and the deferred + "Event loop is closed" errors thrown by httpx/anyio at teardown do not + occur. + + Caveat: tasks the user code spawns via ``asyncio.create_task`` and + leaves running outlive the current call too. That is the price of + keeping shared resources alive; ``cancel_event`` does not reach them. + """ + loop = _get_thread_loop() + return loop.run_until_complete(coro) + + +def _shutdown_thread_loops() -> None: + """Best-effort cleanup of all persistent worker-thread loops at exit.""" + with _loop_registry_lock: + loops = list(_loop_registry) + _loop_registry.clear() + + for loop in loops: + if loop.is_closed(): + continue try: - pending = asyncio.all_tasks(loop) - if pending: - for task in pending: - task.cancel() - loop.run_until_complete( - asyncio.gather(*pending, return_exceptions=True), - ) + # Owning thread is no longer running the loop; safe to drive + # teardown from this (exiting) thread. loop.run_until_complete(loop.shutdown_asyncgens()) loop.run_until_complete(loop.shutdown_default_executor()) + except Exception: + pass finally: - loop.call_soon(loop.stop) - loop.run_forever() - asyncio.set_event_loop(None) - loop.close() + try: + loop.close() + except Exception: + pass + + +atexit.register(_shutdown_thread_loops) # Use negative values to indicate unsigned ints / binary data / usec time precision @@ -1247,7 +1283,7 @@ async def __call__( func_task = asyncio.create_task( to_thread( - lambda: _run_with_graceful_shutdown( + lambda: _run_on_thread_loop( _cancellable_run( cancel_event, func(cancel_event, call_timer, *inputs), @@ -1270,19 +1306,21 @@ async def __call__( all_tasks, return_when=asyncio.FIRST_COMPLETED, ) - await cancel_all_tasks(pending) + # Signal the worker before awaiting cancellation: cancelling + # func_task only flips its asyncio wrapper, not the executor + # work; only cancel_event reaches the worker loop. if func_task in pending: cancel_event.set() + await cancel_all_tasks(pending) + for task in done: if task is disconnect_task: - cancel_event.set() raise asyncio.CancelledError( 'Function call was cancelled by client disconnect', ) elif task is timeout_task: - cancel_event.set() raise asyncio.TimeoutError( 'Function call was cancelled due to timeout', ) diff --git a/singlestoredb/tests/test_udf_event_loop.py b/singlestoredb/tests/test_udf_event_loop.py index 279f4a184..ace452112 100644 --- a/singlestoredb/tests/test_udf_event_loop.py +++ b/singlestoredb/tests/test_udf_event_loop.py @@ -1,4 +1,4 @@ -"""Tests for async UDF event loop graceful shutdown.""" +"""Tests for the async UDF persistent per-thread event loop.""" import asyncio import contextvars import threading @@ -8,134 +8,11 @@ from typing import List from ..functions.ext.asgi import _cancellable_run -from ..functions.ext.asgi import _run_with_graceful_shutdown +from ..functions.ext.asgi import _get_thread_loop +from ..functions.ext.asgi import _run_on_thread_loop from ..functions.ext.asgi import to_thread -class TestRunWithGracefulShutdown(unittest.TestCase): - """Test _run_with_graceful_shutdown handles loop cleanup properly.""" - - def test_basic_coroutine(self) -> None: - async def simple() -> int: - return 42 - - result = _run_with_graceful_shutdown(simple()) - self.assertEqual(result, 42) - - def test_callbacks_drained_before_close(self) -> None: - """Simulate httpx/anyio scheduling call_soon during teardown. - - This is the exact pattern that causes 'Event loop is closed' with - asyncio.run() -- a library schedules a callback in its __del__ or - aclose() that fires after the loop is closed. - """ - callback_executed: List[bool] = [] - - async def coroutine_with_cleanup_callback() -> str: - loop = asyncio.get_running_loop() - loop.call_soon(lambda: callback_executed.append(True)) - return 'done' - - result = _run_with_graceful_shutdown(coroutine_with_cleanup_callback()) - self.assertEqual(result, 'done') - self.assertEqual(callback_executed, [True]) - - def test_no_event_loop_closed_error(self) -> None: - """Verify no RuntimeError when cleanup schedules on the loop.""" - errors: List[RuntimeError] = [] - - async def simulate_httpx_teardown() -> str: - loop = asyncio.get_running_loop() - - def deferred_cleanup() -> None: - try: - loop.call_soon(lambda: None) - except RuntimeError as e: - errors.append(e) - - loop.call_soon(deferred_cleanup) - return 'ok' - - result = _run_with_graceful_shutdown(simulate_httpx_teardown()) - self.assertEqual(result, 'ok') - self.assertEqual(errors, []) - - def test_exception_propagates(self) -> None: - async def failing() -> None: - raise ValueError('test error') - - with self.assertRaises(ValueError) as ctx: - _run_with_graceful_shutdown(failing()) - self.assertEqual(str(ctx.exception), 'test error') - - def test_callbacks_drained_even_on_exception(self) -> None: - """Cleanup callbacks still run even if coroutine raises.""" - callback_executed: List[bool] = [] - - async def failing_with_callback() -> None: - loop = asyncio.get_running_loop() - loop.call_soon(lambda: callback_executed.append(True)) - raise ValueError('boom') - - with self.assertRaises(ValueError): - _run_with_graceful_shutdown(failing_with_callback()) - self.assertEqual(callback_executed, [True]) - - def test_pending_tasks_cancelled(self) -> None: - """Background tasks are cancelled during shutdown.""" - async def background() -> None: - await asyncio.sleep(999) - - async def main_with_background_task() -> str: - asyncio.create_task(background()) - return 'done' - - result = _run_with_graceful_shutdown(main_with_background_task()) - self.assertEqual(result, 'done') - - def test_isolation_between_calls(self) -> None: - """Each call gets its own event loop that is closed after use.""" - loops: List[asyncio.AbstractEventLoop] = [] - - async def capture_loop() -> bool: - loops.append(asyncio.get_running_loop()) - return True - - _run_with_graceful_shutdown(capture_loop()) - first_loop = loops[0] - self.assertTrue(first_loop.is_closed()) - - _run_with_graceful_shutdown(capture_loop()) - second_loop = loops[1] - self.assertTrue(second_loop.is_closed()) - - def test_cancellable_run_integration(self) -> None: - """Verify _cancellable_run works inside _run_with_graceful_shutdown.""" - cancel_event = threading.Event() - - async def slow_func() -> str: - return 'completed' - - result = _run_with_graceful_shutdown( - _cancellable_run(cancel_event, slow_func()), - ) - self.assertEqual(result, 'completed') - - def test_cancellation_via_event(self) -> None: - """Verify cancellation propagates through the full stack.""" - cancel_event = threading.Event() - cancel_event.set() - - async def blocked_func() -> str: - await asyncio.sleep(999) - return 'should not reach' - - with self.assertRaises(asyncio.CancelledError): - _run_with_graceful_shutdown( - _cancellable_run(cancel_event, blocked_func()), - ) - - class TestUDFDispatchEdgeCases(unittest.TestCase): """Test edge cases in the UDF dispatch stack.""" @@ -156,7 +33,7 @@ def set_cancel_after_delay() -> None: start = time.monotonic() with self.assertRaises(asyncio.CancelledError): - _run_with_graceful_shutdown( + _run_on_thread_loop( _cancellable_run(cancel_event, long_running()), ) elapsed = time.monotonic() - start @@ -175,7 +52,7 @@ async def failing_udf() -> None: raise CustomUDFError('embedding service unavailable') with self.assertRaises(CustomUDFError) as ctx: - _run_with_graceful_shutdown( + _run_on_thread_loop( _cancellable_run(cancel_event, failing_udf()), ) self.assertEqual(str(ctx.exception), 'embedding service unavailable') @@ -197,7 +74,7 @@ def set_cancel() -> None: start = time.monotonic() with self.assertRaises(asyncio.CancelledError): - _run_with_graceful_shutdown( + _run_on_thread_loop( _cancellable_run(cancel_event, blocked()), ) elapsed = time.monotonic() - start @@ -221,7 +98,7 @@ def read_context_var() -> str: async def run_in_thread() -> str: return await to_thread(read_context_var) - result = _run_with_graceful_shutdown(run_in_thread()) + result = _run_on_thread_loop(run_in_thread()) self.assertEqual(result, 'hello_from_parent') self.assertEqual(captured, ['hello_from_parent']) @@ -234,7 +111,7 @@ async def compute() -> int: await asyncio.sleep(0.05) return index * 10 - results[index] = _run_with_graceful_shutdown(compute()) + results[index] = _run_on_thread_loop(compute()) threads = [ threading.Thread(target=run_isolated, args=(i,)) @@ -255,7 +132,7 @@ async def sync_as_async() -> int: # Simulates what decorator.py's async_wrapper does for sync UDFs return 42 + 1 - result = _run_with_graceful_shutdown( + result = _run_on_thread_loop( _cancellable_run(cancel_event, sync_as_async()), ) self.assertEqual(result, 43) @@ -267,29 +144,148 @@ def test_cancel_event_not_set_on_success(self) -> None: async def quick() -> str: return 'fast' - result = _run_with_graceful_shutdown( + result = _run_on_thread_loop( _cancellable_run(cancel_event, quick()), ) self.assertEqual(result, 'fast') self.assertFalse(cancel_event.is_set()) - def test_callbacks_from_cancelled_tasks_still_drain(self) -> None: - """Background task callbacks drain even when task is cancelled.""" - drained: List[bool] = [] - async def bg_with_callback() -> None: +class TestRunOnThreadLoop(unittest.TestCase): + """Test _run_on_thread_loop reuses a persistent per-thread event loop.""" + + def test_basic_coroutine(self) -> None: + async def simple() -> int: + return 42 + + self.assertEqual(_run_on_thread_loop(simple()), 42) + + def test_loop_reused_across_calls(self) -> None: + """The same loop object is reused for successive calls in a thread.""" + loops: List[asyncio.AbstractEventLoop] = [] + + async def capture_loop() -> bool: + loops.append(asyncio.get_running_loop()) + return True + + _run_on_thread_loop(capture_loop()) + _run_on_thread_loop(capture_loop()) + + self.assertIs(loops[0], loops[1]) + + def test_loop_not_closed_between_calls(self) -> None: + """The persistent loop stays open so resources survive requests.""" + captured: List[asyncio.AbstractEventLoop] = [] + + async def capture_loop() -> bool: + captured.append(asyncio.get_running_loop()) + return True + + _run_on_thread_loop(capture_loop()) + loop = captured[0] + self.assertFalse(loop.is_closed()) + + # Still usable for the next request. + _run_on_thread_loop(capture_loop()) + self.assertFalse(loop.is_closed()) + + def test_async_resource_survives_between_calls(self) -> None: + """An object bound to the loop can be reused on the next call. + + This mirrors caching e.g. an httpx.AsyncClient keyed by the loop and + reusing its connection pool on subsequent requests. + """ + clients: dict = {} + + async def get_or_create_client() -> int: loop = asyncio.get_running_loop() - loop.call_soon(lambda: drained.append(True)) + if loop not in clients: + clients[loop] = object() + return id(clients[loop]) + + first = _run_on_thread_loop(get_or_create_client()) + second = _run_on_thread_loop(get_or_create_client()) + + self.assertEqual(first, second) + self.assertEqual(len(clients), 1) + + def test_separate_threads_get_separate_loops(self) -> None: + """Each worker thread owns its own persistent loop.""" + loops: List[asyncio.AbstractEventLoop] = [] + lock = threading.Lock() + + def run_in_thread() -> None: + async def capture() -> bool: + with lock: + loops.append(asyncio.get_running_loop()) + return True + + _run_on_thread_loop(capture()) + + threads = [threading.Thread(target=run_in_thread) for _ in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(loops), 3) + self.assertEqual(len({id(loop) for loop in loops}), 3) + + def test_get_thread_loop_idempotent(self) -> None: + """_get_thread_loop returns the same loop on repeated calls.""" + def run_in_thread(out: List[asyncio.AbstractEventLoop]) -> None: + out.append(_get_thread_loop()) + out.append(_get_thread_loop()) + + out: List[asyncio.AbstractEventLoop] = [] + t = threading.Thread(target=run_in_thread, args=(out,)) + t.start() + t.join() + + self.assertIs(out[0], out[1]) + + def test_exception_propagates(self) -> None: + async def failing() -> None: + raise ValueError('test error') + + with self.assertRaises(ValueError) as ctx: + _run_on_thread_loop(failing()) + self.assertEqual(str(ctx.exception), 'test error') + + def test_cancellable_run_integration(self) -> None: + """_cancellable_run works on the persistent loop.""" + cancel_event = threading.Event() + + async def slow_func() -> str: + return 'completed' + + result = _run_on_thread_loop( + _cancellable_run(cancel_event, slow_func()), + ) + self.assertEqual(result, 'completed') + + def test_cancellation_via_event(self) -> None: + """Cancellation propagates through the persistent-loop stack.""" + cancel_event = threading.Event() + cancel_event.set() + + async def blocked_func() -> str: await asyncio.sleep(999) + return 'should not reach' + + with self.assertRaises(asyncio.CancelledError): + _run_on_thread_loop( + _cancellable_run(cancel_event, blocked_func()), + ) - async def main() -> str: - asyncio.create_task(bg_with_callback()) - await asyncio.sleep(0.05) # Let background task start - return 'done' + # Loop must remain usable after a cancelled request. + async def quick() -> str: + return 'ok' - result = _run_with_graceful_shutdown(main()) - self.assertEqual(result, 'done') - self.assertEqual(drained, [True]) + self.assertEqual( + _run_on_thread_loop(_cancellable_run(threading.Event(), quick())), + 'ok', + ) if __name__ == '__main__': From 952ef2dea33bcc68443e04222a5ffc247d16cd15 Mon Sep 17 00:00:00 2001 From: Ankit Saini Date: Wed, 10 Jun 2026 09:32:17 +0530 Subject: [PATCH 13/19] dedicated async thread --- pyproject.toml | 2 +- singlestoredb/functions/ext/asgi.py | 162 +++++- singlestoredb/tests/test_udf_event_loop.py | 543 +++++++++++++++++++++ 3 files changed, 701 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bb436d8ea..0f57263e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "singlestoredb" -version = "1.16.10" +version = "1.16.11rc1+byasaini" description = "Interface to the SingleStoreDB database and workspace management APIs" readme = {file = "README.md", content-type = "text/markdown"} license = {text = "Apache-2.0"} diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 294b17c21..e3967cc7a 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -139,6 +139,10 @@ async def _cancellable_run( # Each `to_thread` worker thread owns a long-lived event loop reused across # requests, so loop-bound resources (HTTP pools, DB sessions, sockets) can # survive between calls handled by the same thread. +# +# This per-thread loop is only used for SYNC user UDFs: a sync UDF blocks +# its worker thread for the duration of the call, so giving each worker +# thread its own loop avoids cross-thread loop sharing for those calls. _thread_local = threading.local() _loop_registry: 'Set[asyncio.AbstractEventLoop]' = set() _loop_registry_lock = threading.Lock() @@ -199,6 +203,137 @@ def _shutdown_thread_loops() -> None: atexit.register(_shutdown_thread_loops) +# Dedicated event loop used for ALL async UDF requests. +# +# Async UDFs commonly create resources that are bound to the event loop +# they are first used on (httpx connection pools, async DB clients, anyio +# streams, ...). Dispatching async requests to ad-hoc worker threads — +# each with its own loop — produced " is bound +# to a different event loop" errors when those cached resources were +# reused by a request that landed on a different worker thread. +# +# Routing every async UDF onto a single dedicated loop fixes that, and +# also gives true concurrency across requests: ``run_coroutine_threadsafe`` +# schedules each new coroutine immediately so that incoming requests do +# not queue behind in-flight ones. +# +# Sync UDFs intentionally still go through the worker-thread / per-thread +# loop path above: a sync UDF would block this dedicated loop and starve +# every other in-flight async request. +_async_dispatch_loop: 'Optional[asyncio.AbstractEventLoop]' = None +_async_dispatch_thread: 'Optional[threading.Thread]' = None +_async_dispatch_lock = threading.Lock() + + +def _get_async_dispatch_loop() -> asyncio.AbstractEventLoop: + """ + Return (lazily creating) the singleton async-dispatch event loop. + + The loop is owned by a dedicated daemon thread that runs ``run_forever`` + for the lifetime of the process. All async UDF coroutines are scheduled + on this loop so that loop-bound resources can be safely reused across + requests. + """ + global _async_dispatch_loop, _async_dispatch_thread + + loop = _async_dispatch_loop + if loop is not None and not loop.is_closed(): + return loop + + with _async_dispatch_lock: + if _async_dispatch_loop is not None and \ + not _async_dispatch_loop.is_closed(): + return _async_dispatch_loop + + ready = threading.Event() + captured: List[asyncio.AbstractEventLoop] = [] + + def run_loop() -> None: + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + captured.append(new_loop) + ready.set() + try: + new_loop.run_forever() + finally: + try: + new_loop.run_until_complete(new_loop.shutdown_asyncgens()) + except Exception: + pass + try: + new_loop.run_until_complete( + new_loop.shutdown_default_executor(), + ) + except Exception: + pass + try: + new_loop.close() + except Exception: + pass + + thread = threading.Thread( + target=run_loop, + name='singlestoredb-udf-async-dispatch', + daemon=True, + ) + thread.start() + ready.wait() + + _async_dispatch_loop = captured[0] + _async_dispatch_thread = thread + return _async_dispatch_loop + + +def _get_async_dispatch_thread() -> 'Optional[threading.Thread]': + """Return the dedicated dispatch thread (or ``None`` if not started).""" + return _async_dispatch_thread + + +async def _dispatch_to_async_loop(coro: Any) -> Any: + """ + Schedule ``coro`` on the dedicated async-dispatch loop and await its result. + + The coroutine begins running immediately on the dispatch loop — it does + NOT wait for any earlier in-flight async UDF to complete — so concurrent + async requests run in parallel on a single shared loop. + + Cancellation of the awaiting task on the caller's loop is propagated + best-effort to the work scheduled on the dispatch loop. The user code's + ``cancel_event`` (set by the request handler on timeout / disconnect) + remains the authoritative cancellation signal because it is observed by + ``_cancellable_run`` from inside the dispatch loop. + """ + loop = _get_async_dispatch_loop() + cf = asyncio.run_coroutine_threadsafe(coro, loop) + try: + return await asyncio.wrap_future(cf) + except asyncio.CancelledError: + cf.cancel() + raise + + +def _shutdown_async_dispatch_loop() -> None: + """Best-effort cleanup of the dedicated async-dispatch loop at exit.""" + global _async_dispatch_loop, _async_dispatch_thread + with _async_dispatch_lock: + loop = _async_dispatch_loop + thread = _async_dispatch_thread + _async_dispatch_loop = None + _async_dispatch_thread = None + + if loop is not None and not loop.is_closed(): + try: + loop.call_soon_threadsafe(loop.stop) + except Exception: + pass + + if thread is not None: + thread.join(timeout=5) + + +atexit.register(_shutdown_async_dispatch_loop) + + # Use negative values to indicate unsigned ints / binary data / usec time precision rowdat_1_type_map = { 'bool': ft.LONGLONG, @@ -1281,16 +1416,33 @@ async def __call__( func_info['colspec'], b''.join(data), ) - func_task = asyncio.create_task( - to_thread( - lambda: _run_on_thread_loop( + # Async user UDFs share a single dedicated event-loop thread + # so that loop-bound resources (httpx pools, async clients, + # ...) can be reused across requests; new requests are + # scheduled immediately and run concurrently on that loop. + # Sync user UDFs continue to use the worker-thread pool (one + # persistent loop per thread) because a sync call would + # block the shared dispatch loop and starve other requests. + if func_info.get('is_async'): + func_task = asyncio.create_task( + _dispatch_to_async_loop( _cancellable_run( cancel_event, func(cancel_event, call_timer, *inputs), ), ), - ), - ) + ) + else: + func_task = asyncio.create_task( + to_thread( + lambda: _run_on_thread_loop( + _cancellable_run( + cancel_event, + func(cancel_event, call_timer, *inputs), + ), + ), + ), + ) disconnect_task = asyncio.create_task( asyncio.sleep(int(1e9)) if ignore_cancel else cancel_on_disconnect(receive), diff --git a/singlestoredb/tests/test_udf_event_loop.py b/singlestoredb/tests/test_udf_event_loop.py index ace452112..7c89c397a 100644 --- a/singlestoredb/tests/test_udf_event_loop.py +++ b/singlestoredb/tests/test_udf_event_loop.py @@ -1,15 +1,24 @@ """Tests for the async UDF persistent per-thread event loop.""" import asyncio import contextvars +import json as jsonlib import threading import time import unittest from typing import Any +from typing import Dict from typing import List +from typing import Set +from typing import Tuple +from ..functions import udf from ..functions.ext.asgi import _cancellable_run +from ..functions.ext.asgi import _dispatch_to_async_loop +from ..functions.ext.asgi import _get_async_dispatch_loop +from ..functions.ext.asgi import _get_async_dispatch_thread from ..functions.ext.asgi import _get_thread_loop from ..functions.ext.asgi import _run_on_thread_loop +from ..functions.ext.asgi import Application from ..functions.ext.asgi import to_thread @@ -288,5 +297,539 @@ async def quick() -> str: ) +class TestAsyncDispatchLoop(unittest.TestCase): + """All async UDF dispatches share a single dedicated event-loop thread. + + The dispatch loop is process-global and lazily started; resources bound + to it (HTTP pools, async clients, connection caches) are reused across + every async UDF request. New requests are scheduled immediately and run + concurrently on that loop instead of being serialized behind earlier + in-flight requests. + """ + + def test_dispatch_loop_is_single_dedicated_thread(self) -> None: + """All dispatches run on the same dedicated thread (not the caller).""" + seen_threads: Set[int] = set() + + async def capture() -> int: + seen_threads.add(threading.get_ident()) + return 1 + + async def run_many() -> None: + await asyncio.gather(*[ + _dispatch_to_async_loop(capture()) for _ in range(8) + ]) + + caller_thread = threading.get_ident() + asyncio.run(run_many()) + + self.assertEqual(len(seen_threads), 1) + self.assertNotIn(caller_thread, seen_threads) + # The thread we observed is the singleton dispatch thread. + dispatch_thread = _get_async_dispatch_thread() + assert dispatch_thread is not None + self.assertEqual(seen_threads.pop(), dispatch_thread.ident) + + def test_dispatch_loop_is_single_event_loop(self) -> None: + """All dispatches run on the SAME event loop instance.""" + captured: List[asyncio.AbstractEventLoop] = [] + + async def capture() -> int: + captured.append(asyncio.get_running_loop()) + return 1 + + async def run_many() -> None: + await asyncio.gather(*[ + _dispatch_to_async_loop(capture()) for _ in range(5) + ]) + + asyncio.run(run_many()) + + self.assertEqual(len(captured), 5) + first = captured[0] + for loop in captured: + self.assertIs(loop, first) + self.assertIs(first, _get_async_dispatch_loop()) + + def test_concurrent_dispatches_do_not_serialize(self) -> None: + """Slow dispatches run in parallel on the loop; new requests do not + wait for earlier ones to finish.""" + n = 6 + per_call_sleep = 0.3 + + async def slow() -> str: + await asyncio.sleep(per_call_sleep) + return 'done' + + async def run_many() -> List[str]: + return await asyncio.gather(*[ + _dispatch_to_async_loop(slow()) for _ in range(n) + ]) + + start = time.monotonic() + results = asyncio.run(run_many()) + elapsed = time.monotonic() - start + + self.assertEqual(results, ['done'] * n) + # Serialized would be ~ n * per_call_sleep. Parallel ~ per_call_sleep. + # Allow generous margin for CI noise. + self.assertLess(elapsed, per_call_sleep * 2) + + def test_new_request_does_not_wait_for_in_flight_request(self) -> None: + """A new async request is submitted to the dispatch thread + immediately and runs while an earlier request is still in-flight. + + This is the explicit guarantee that async UDF dispatch is not + serialized: a request fired AFTER another long one has started + must (a) start before the long one finishes, (b) finish before + the long one finishes, and (c) be submitted with negligible + latency from the caller's perspective. + """ + long_sleep = 1.0 + ts: Dict[str, float] = {} + # Created lazily on the dispatch loop so the asyncio.Event is bound + # to the correct loop. + signals: Dict[str, asyncio.Event] = {} + + async def long_running() -> str: + ts['long_started'] = time.monotonic() + signals['started'] = asyncio.Event() + signals['started'].set() + await asyncio.sleep(long_sleep) + ts['long_finished'] = time.monotonic() + return 'long' + + async def quick() -> str: + ts['quick_started'] = time.monotonic() + await asyncio.sleep(0) + ts['quick_finished'] = time.monotonic() + return 'quick' + + async def driver() -> None: + long_task = asyncio.create_task( + _dispatch_to_async_loop(long_running()), + ) + # Wait until the long task has actually started on the + # dispatch loop. Only after this point can we be sure the + # next dispatch is "during" an in-flight request. + for _ in range(100): + await asyncio.sleep(0.01) + if 'started' in signals and signals['started'].is_set(): + break + self.assertIn('long_started', ts) + + ts['quick_dispatch_called'] = time.monotonic() + quick_result = await _dispatch_to_async_loop(quick()) + ts['quick_dispatch_returned'] = time.monotonic() + self.assertEqual(quick_result, 'quick') + + long_result = await long_task + self.assertEqual(long_result, 'long') + + asyncio.run(driver()) + + # The new request actually overlapped the in-flight one. + self.assertGreater(ts['quick_started'], ts['long_started']) + self.assertLess(ts['quick_started'], ts['long_finished']) + self.assertLess(ts['quick_finished'], ts['long_finished']) + + # Submission of the new request to the dispatch thread is + # non-blocking: the awaiter returned in well under the long + # request's remaining time. + dispatch_latency = ts['quick_dispatch_returned'] \ + - ts['quick_dispatch_called'] + self.assertLess(dispatch_latency, long_sleep / 2) + + def test_many_new_requests_run_during_one_in_flight_request(self) -> None: + """Many new async requests, each fired sequentially while a single + long-running request is in-flight, all start AND finish before the + long one finishes.""" + long_sleep = 1.0 + n_quick = 8 + ts: Dict[str, float] = {} + quick_finished: List[float] = [] + signals: Dict[str, asyncio.Event] = {} + + async def long_running() -> str: + ts['long_started'] = time.monotonic() + signals['started'] = asyncio.Event() + signals['started'].set() + await asyncio.sleep(long_sleep) + ts['long_finished'] = time.monotonic() + return 'long' + + async def quick(i: int) -> int: + await asyncio.sleep(0.01) + quick_finished.append(time.monotonic()) + return i + + async def driver() -> None: + long_task = asyncio.create_task( + _dispatch_to_async_loop(long_running()), + ) + # Wait for the long task to start. + for _ in range(100): + await asyncio.sleep(0.01) + if 'started' in signals and signals['started'].is_set(): + break + + results = await asyncio.gather(*[ + _dispatch_to_async_loop(quick(i)) for i in range(n_quick) + ]) + self.assertEqual(results, list(range(n_quick))) + await long_task + + asyncio.run(driver()) + + # All quick requests finished before the long one did, proving + # they were not queued behind it. + self.assertEqual(len(quick_finished), n_quick) + for finish in quick_finished: + self.assertLess(finish, ts['long_finished']) + self.assertGreater(finish, ts['long_started']) + + def test_loop_bound_resource_reused_across_dispatches(self) -> None: + """A resource keyed by id(loop) is shared by every async request, + even across separate caller event loops (separate parent runs).""" + cache: Dict[int, object] = {} + + async def acquire() -> int: + loop = asyncio.get_running_loop() + key = id(loop) + if key not in cache: + cache[key] = object() + return id(cache[key]) + + async def run_one() -> int: + return await _dispatch_to_async_loop(acquire()) + + first = asyncio.run(run_one()) + second = asyncio.run(run_one()) + third = asyncio.run(run_one()) + + self.assertEqual(first, second) + self.assertEqual(second, third) + self.assertEqual(len(cache), 1) + + def test_dispatch_propagates_exception(self) -> None: + """Exceptions from the dispatched coroutine surface to the caller.""" + class DispatchedError(Exception): + pass + + async def failing() -> None: + raise DispatchedError('boom') + + async def driver() -> None: + await _dispatch_to_async_loop(failing()) + + with self.assertRaises(DispatchedError) as ctx: + asyncio.run(driver()) + self.assertEqual(str(ctx.exception), 'boom') + + def test_dispatch_with_cancel_event(self) -> None: + """`_cancellable_run` on the dispatch loop honors the cancel event.""" + cancel_event = threading.Event() + + async def blocked() -> str: + await asyncio.sleep(999) + return 'unreachable' + + def trip_cancel() -> None: + time.sleep(0.1) + cancel_event.set() + + timer = threading.Thread(target=trip_cancel) + timer.start() + + async def driver() -> None: + await _dispatch_to_async_loop( + _cancellable_run(cancel_event, blocked()), + ) + + start = time.monotonic() + with self.assertRaises(asyncio.CancelledError): + asyncio.run(driver()) + elapsed = time.monotonic() - start + timer.join() + # 0.1s delay + 0.1s poll interval + margin + self.assertLess(elapsed, 0.5) + + def test_dispatch_loop_survives_after_cancellation(self) -> None: + """The dispatch loop remains usable after a cancelled request.""" + cancel_event = threading.Event() + cancel_event.set() + + async def blocked() -> str: + await asyncio.sleep(999) + return 'unreachable' + + async def driver_cancel() -> None: + await _dispatch_to_async_loop( + _cancellable_run(cancel_event, blocked()), + ) + + with self.assertRaises(asyncio.CancelledError): + asyncio.run(driver_cancel()) + + async def quick() -> str: + return 'ok' + + async def driver_ok() -> str: + return await _dispatch_to_async_loop(quick()) + + self.assertEqual(asyncio.run(driver_ok()), 'ok') + + +# Module-level UDFs used by the Application integration tests below. They +# must be defined at module scope so the signature inspection helpers can +# resolve their type hints. + +# Records the thread that actually executes each UDF body, keyed by tag. +_dispatch_observation: Dict[str, int] = {} +_dispatch_observation_lock = threading.Lock() +# Per-tag start / finish timestamps, used by the "no waiting for in-flight" +# test below to assert overlap between concurrent requests. +_dispatch_started_at: Dict[str, float] = {} +_dispatch_finished_at: Dict[str, float] = {} + + +def _record(tag: str) -> None: + with _dispatch_observation_lock: + _dispatch_observation[tag] = threading.get_ident() + _dispatch_started_at[tag] = time.monotonic() + + +def _record_finish(tag: str) -> None: + with _dispatch_observation_lock: + _dispatch_finished_at[tag] = time.monotonic() + + +@udf +async def _async_record_udf(tag: str) -> int: + _record(tag) + await asyncio.sleep(0) + _record_finish(tag) + return len(tag) + + +@udf +async def _async_slow_udf(tag: str) -> int: + _record(tag) + await asyncio.sleep(0.4) + _record_finish(tag) + return len(tag) + + +@udf +async def _async_long_udf(tag: str) -> int: + """Long-running async UDF used to verify that newly arriving async + requests do not have to wait for it to finish.""" + _record(tag) + await asyncio.sleep(1.0) + _record_finish(tag) + return len(tag) + + +@udf +def _sync_record_udf(tag: str) -> int: + _record(tag) + _record_finish(tag) + return len(tag) + + +def _make_invoke_args( + name: str, + rows: List[Tuple[Any, ...]], +) -> Tuple[Dict[str, Any], Any, Any, List[Dict[str, Any]]]: + """Build a minimal ASGI scope/receive/send for an /invoke request.""" + payload = jsonlib.dumps({ + 'data': [[i, *row] for i, row in enumerate(rows)], + }).encode('utf-8') + + received: Dict[str, bool] = {'sent': False} + + async def receive() -> Dict[str, Any]: + if received['sent']: + await asyncio.sleep(60) + return {'type': 'http.disconnect'} + received['sent'] = True + return {'type': 'http.request', 'body': payload, 'more_body': False} + + sent: List[Dict[str, Any]] = [] + + async def send(msg: Dict[str, Any]) -> None: + sent.append(msg) + + scope = { + 'type': 'http', + 'method': 'POST', + 'path': '/invoke', + 'scheme': 'http', + 'headers': [ + (b'content-type', b'application/json'), + (b'accepts', b'application/json'), + (b's2-ef-name', name.encode('utf-8')), + (b's2-ef-version', b'1.0'), + (b's2-ef-ignore-cancel', b'true'), + ], + } + return scope, receive, send, sent + + +def _reset_dispatch_observation() -> None: + with _dispatch_observation_lock: + _dispatch_observation.clear() + _dispatch_started_at.clear() + _dispatch_finished_at.clear() + + +class TestApplicationDispatchRouting(unittest.TestCase): + """End-to-end: Application routes async UDFs to the dispatch loop and + sync UDFs to a worker thread, and concurrent async requests run in + parallel on the dispatch loop.""" + + def setUp(self) -> None: + _reset_dispatch_observation() + self.app = Application( + functions=[ + _async_record_udf, + _async_slow_udf, + _async_long_udf, + _sync_record_udf, + ], + disable_metrics=True, + ) + + @staticmethod + def _headers_dict(scope: Dict[str, Any]) -> Dict[bytes, bytes]: + return {k: v for k, v in scope['headers']} + + def _invoke(self, name: str, rows: List[Tuple[Any, ...]]) -> List[Dict[str, Any]]: + scope, receive, send, sent = _make_invoke_args(name, rows) + scope['headers'] = list(scope['headers']) + # Application reads headers as a dict via ``dict(scope['headers'])``, + # which works for our list of tuples. + asyncio.run(self.app(scope, receive, send)) + return sent + + async def _invoke_async( + self, name: str, rows: List[Tuple[Any, ...]], + ) -> List[Dict[str, Any]]: + scope, receive, send, sent = _make_invoke_args(name, rows) + scope['headers'] = list(scope['headers']) + await self.app(scope, receive, send) + return sent + + def test_async_udf_runs_on_dispatch_thread(self) -> None: + """An async UDF body executes on the dedicated dispatch thread.""" + sent = self._invoke('_async_record_udf', [('alpha',)]) + statuses = [m for m in sent if m.get('type') == 'http.response.start'] + self.assertTrue(statuses and statuses[0]['status'] == 200, sent) + + dispatch_thread = _get_async_dispatch_thread() + assert dispatch_thread is not None + with _dispatch_observation_lock: + self.assertEqual(_dispatch_observation['alpha'], dispatch_thread.ident) + + def test_sync_udf_runs_on_a_worker_thread_not_dispatch(self) -> None: + """A sync UDF body runs on a worker thread, NOT the dispatch thread.""" + # Force the dispatch thread to exist so we can compare ids. + _get_async_dispatch_loop() + dispatch_thread = _get_async_dispatch_thread() + assert dispatch_thread is not None + + sent = self._invoke('_sync_record_udf', [('beta',)]) + statuses = [m for m in sent if m.get('type') == 'http.response.start'] + self.assertTrue(statuses and statuses[0]['status'] == 200, sent) + + with _dispatch_observation_lock: + sync_thread = _dispatch_observation['beta'] + + self.assertNotEqual(sync_thread, threading.get_ident()) + self.assertNotEqual(sync_thread, dispatch_thread.ident) + + def test_concurrent_async_requests_share_dispatch_thread(self) -> None: + """Two concurrent async UDF requests both execute on the dispatch thread.""" + + async def driver() -> None: + await asyncio.gather( + self._invoke_async('_async_record_udf', [('one',)]), + self._invoke_async('_async_record_udf', [('two',)]), + self._invoke_async('_async_record_udf', [('three',)]), + ) + + asyncio.run(driver()) + + dispatch_thread = _get_async_dispatch_thread() + assert dispatch_thread is not None + with _dispatch_observation_lock: + for tag in ('one', 'two', 'three'): + self.assertEqual( + _dispatch_observation[tag], dispatch_thread.ident, + f'tag {tag} ran on wrong thread', + ) + + def test_concurrent_async_requests_do_not_serialize(self) -> None: + """Concurrent async UDF requests run in parallel on the dispatch loop; + a new request does not wait for in-flight ones.""" + n = 4 + per_call_sleep = 0.4 + + async def driver() -> None: + await asyncio.gather(*[ + self._invoke_async('_async_slow_udf', [(f'r{i}',)]) + for i in range(n) + ]) + + start = time.monotonic() + asyncio.run(driver()) + elapsed = time.monotonic() - start + + # Serialized would be ~ n * per_call_sleep. Parallel ~ per_call_sleep. + self.assertLess(elapsed, per_call_sleep * 2) + + def test_new_async_request_runs_during_in_flight_request(self) -> None: + """An async request arriving while another is still running gets + dispatched onto the async thread immediately and finishes before + the in-flight one — i.e., a new request does not wait for any + existing async request to be served.""" + + async def driver() -> None: + long_call = asyncio.create_task( + self._invoke_async('_async_long_udf', [('long',)]), + ) + # Spin until the long request has actually started executing + # on the dispatch thread, so any new dispatch we fire after + # this point is genuinely "during" an in-flight request. + for _ in range(200): + await asyncio.sleep(0.01) + with _dispatch_observation_lock: + if 'long' in _dispatch_started_at: + break + self.assertIn('long', _dispatch_started_at) + + t_call = time.monotonic() + await self._invoke_async('_async_record_udf', [('quick',)]) + t_returned = time.monotonic() + await long_call + + asyncio.run(driver()) + + with _dispatch_observation_lock: + long_started = _dispatch_started_at['long'] + long_finished = _dispatch_finished_at['long'] + quick_started = _dispatch_started_at['quick'] + quick_finished = _dispatch_finished_at['quick'] + + # quick must have started AFTER long started (it was fired later) + # but BEFORE long finished, and itself finished before long did. + self.assertGreater(quick_started, long_started) + self.assertLess(quick_started, long_finished) + self.assertLess(quick_finished, long_finished) + + # Sanity: the long UDF body really did span the long sleep. + self.assertGreaterEqual(long_finished - long_started, 0.9) + + if __name__ == '__main__': unittest.main() From 4448a6989669aeb055c863260d8ca0de7a1ca0b5 Mon Sep 17 00:00:00 2001 From: Ankit Saini Date: Wed, 10 Jun 2026 15:49:28 +0530 Subject: [PATCH 14/19] more logs --- singlestoredb/ai/embeddings.py | 224 ++++++++++++++++++++++++++++ singlestoredb/functions/ext/asgi.py | 24 ++- 2 files changed, 242 insertions(+), 6 deletions(-) diff --git a/singlestoredb/ai/embeddings.py b/singlestoredb/ai/embeddings.py index aba8e1c47..5a023adc5 100644 --- a/singlestoredb/ai/embeddings.py +++ b/singlestoredb/ai/embeddings.py @@ -1,5 +1,10 @@ +import contextvars +import logging import os +import time +import uuid from typing import Any +from typing import AsyncIterator from typing import Callable from typing import Optional from typing import Union @@ -9,6 +14,160 @@ from singlestoredb import manage_workspaces from singlestoredb.management.inference_api import InferenceAPIInfo + +# Per-task trace id propagated into the embeddings HTTP transport. +# +# Callers (e.g. an EMBED_TEXT UDF) can do ``http_trace_id.set("")`` +# right before invoking ``aembed_documents``. The :class:`TracingAsyncTransport` +# reads this var inside ``handle_async_request`` and stamps every log line +# with it, so each per-stage HTTP timing line can be correlated back to the +# UDF request that produced it. +# +# ContextVars are per-asyncio-Task: even with many concurrent EMBED_TEXT +# coroutines sharing a single dispatch loop, each call has its own private +# context, so a set() in one task does not leak into another. +http_trace_id: 'contextvars.ContextVar[str]' = contextvars.ContextVar( + 'singlestoredb_embeddings_http_trace_id', default='-', +) + +_http_log = logging.getLogger('singlestoredb.ai.embeddings.http') + + +class HttpTraceIdFilter(logging.Filter): + """ + Stamps every log record with the current :data:`http_trace_id` value + under ``record.trace_id``. + + Attach this to handlers (or loggers) for ``httpx`` / ``httpcore`` so + that their low-level per-stage log lines (``connect_tcp.started``, + ``start_tls.started``, ``send_request_body.complete`` etc.) carry the + same trace id the caller stamped before invoking ``aembed_documents``. + + Why this works: ``httpcore`` runs inside the same ``asyncio.Task`` as + the caller (it's just deeper in the ``await`` chain), and ContextVars + are per-Task, so ``http_trace_id.get()`` inside ``filter()`` returns + the value set by the caller for that specific request. Each concurrent + EMBED_TEXT call has its own value and they do not bleed across. + """ + + def filter(self, record: logging.LogRecord) -> bool: + record.trace_id = http_trace_id.get() + return True + + +def enable_http_debug_logging(level: int = logging.DEBUG) -> None: + """ + Turn on per-stage httpx / httpcore logging. + + This is very verbose (one log line per TCP connect, TLS handshake, + header send, body send, response header read, response body read, etc.) + but is the fastest way to pinpoint which network phase a stuck embedding + request is sitting in. Enable on demand, e.g. by setting the + ``SINGLESTOREDB_EMBEDDINGS_HTTP_DEBUG=1`` env var or by calling this + function directly. + """ + for name in ( + 'httpx', + 'httpcore', + 'httpcore.connection', + 'httpcore.http11', + 'httpcore.http2', + 'httpcore.proxy', + ): + logging.getLogger(name).setLevel(level) + + +if os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_HTTP_DEBUG', '', +).lower() in ('1', 'true', 'yes'): + enable_http_debug_logging() + + +class TracingAsyncTransport(httpx.AsyncBaseTransport): + """ + Wraps another :class:`httpx.AsyncBaseTransport` and logs per-stage + timings for every request, stamped with the current :data:`http_trace_id` + context value. + + Emits three log lines per request: + + 1. ``->`` when the request is handed to the inner transport, with the + outgoing body size. + 2. ``<- headers`` when the response headers arrive (gives time-to-first- + byte, i.e. the gap that captures DNS + TCP + TLS + upload + upstream + processing). + 3. ``<- body`` when the response body is fully consumed (gives separate + body-download elapsed and total elapsed). + + The 1->2 gap vs the 2->3 gap is what tells you whether a hang is on the + request/upstream side or on the response/download side. + """ + + def __init__(self, inner: httpx.AsyncBaseTransport) -> None: + self._inner = inner + + async def handle_async_request( + self, request: httpx.Request, + ) -> httpx.Response: + tid = http_trace_id.get() + rid = uuid.uuid4().hex[:6] + try: + body_len = len(request.content or b'') + except httpx.RequestNotRead: + body_len = -1 + t0 = time.perf_counter() + _http_log.info( + '[%s/%s] -> %s %s body=%dB', + tid, rid, request.method, request.url, body_len, + ) + try: + response = await self._inner.handle_async_request(request) + except BaseException as e: + _http_log.error( + '[%s/%s] xx EXC after %.3fs: %r', + tid, rid, time.perf_counter() - t0, e, + ) + raise + + t_headers = time.perf_counter() + _http_log.info( + '[%s/%s] <- headers status=%d ttfb=%.3fs', + tid, rid, response.status_code, t_headers - t0, + ) + + # Wrap the body stream so we also time how long the body download + # itself takes. Captures `tid`, `rid`, `t0`, `t_headers` by closure + # so the log line is correctly correlated even though body consumption + # happens later (after handle_async_request has returned). + original_stream = response.stream + + class _TimingStream(httpx.AsyncByteStream): + + async def __aiter__(self) -> AsyncIterator[bytes]: + total = 0 + t_body_start = time.perf_counter() + try: + async for chunk in original_stream: + total += len(chunk) + yield chunk + finally: + t_body_end = time.perf_counter() + _http_log.info( + '[%s/%s] <- body bytes=%d body_elapsed=%.3fs ' + 'total=%.3fs', + tid, rid, total, + t_body_end - t_body_start, t_body_end - t0, + ) + + async def aclose(self) -> None: + await original_stream.aclose() + + response.stream = _TimingStream() + return response + + async def aclose(self) -> None: + await self._inner.aclose() + try: from langchain_openai import OpenAIEmbeddings except ImportError: @@ -34,6 +193,7 @@ def SingleStoreEmbeddingsFactory( model_name: str, api_key: Optional[str] = None, http_client: Optional[httpx.Client] = None, + http_async_client: Optional[httpx.AsyncClient] = None, obo_token_getter: Optional[Callable[[], Optional[str]]] = None, base_url: Optional[str] = None, hosting_platform: Optional[str] = None, @@ -152,6 +312,70 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: ) if http_client is not None: openai_kwargs['http_client'] = http_client + + if http_async_client is None: + # Explicit timeouts: without these, httpx falls back to its 5s + # default at the client level, but the OpenAI SDK overrides that + # with a per-request 600s read timeout, so a stalled response can + # sit on the socket for ~10 minutes before httpx notices. We use a + # tighter read timeout so a dead/half-open connection fails fast + # instead of waiting for the application-level defensive timeout + # (e.g. EMBED_TEXT's asyncio.wait_for) to fire. + client_timeout = httpx.Timeout( + connect=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_CONNECT_TIMEOUT', '10', + ), + ), + read=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_READ_TIMEOUT', '60', + ), + ), + write=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_WRITE_TIMEOUT', '30', + ), + ), + pool=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_POOL_TIMEOUT', '10', + ), + ), + ) + # Allow connection reuse. The previous configuration + # (max_keepalive_connections=0) forced a fresh TCP+TLS handshake + # for every request, which under heavy concurrency churns sockets + # and occasionally yields one connection that the upstream accepts + # but never finishes responding on. + client_limits = httpx.Limits( + max_connections=int( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_MAX_CONNECTIONS', '64', + ), + ), + max_keepalive_connections=int( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_MAX_KEEPALIVE', '16', + ), + ), + keepalive_expiry=float( + os.environ.get( + 'SINGLESTOREDB_EMBEDDINGS_KEEPALIVE_EXPIRY', '30', + ), + ), + ) + http_async_client = httpx.AsyncClient( + timeout=client_timeout, + limits=client_limits, + transport=TracingAsyncTransport( + httpx.AsyncHTTPTransport( + limits=client_limits, + ), + ), + ) + openai_kwargs['http_async_client'] = http_async_client + return OpenAIEmbeddings( **openai_kwargs, **kwargs, diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index e3967cc7a..dafa7650d 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -1411,9 +1411,15 @@ async def __call__( cancel_event = threading.Event() - with timer('parse_input'): - inputs = input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), + # Parsing the request body can be CPU heavy (esp. for + # rowdat_1 / arrow payloads). Run it in the default + # executor thread pool so the main uvicorn loop is not + # blocked while inputs are being decoded. + load_input = input_handler['load'] # type: ignore + colspec = func_info['colspec'] + async with timer('parse_input'): + inputs = await to_thread( + lambda: load_input(colspec, b''.join(data)), ) # Async user UDFs share a single dedicated event-loop thread @@ -1480,9 +1486,15 @@ async def __call__( elif task is func_task: result.extend(task.result()) - with timer('format_output'): - body = output_handler['dump']( - [x[1] for x in func_info['returns']], *result, # type: ignore + # Serializing the response can also be CPU heavy. Run it + # in the default executor thread pool so the main + # uvicorn loop stays responsive to other connections + # while this request is being encoded. + dump_output = output_handler['dump'] # type: ignore + return_types = [x[1] for x in func_info['returns']] + async with timer('format_output'): + body = await to_thread( + dump_output, return_types, *result, ) await send(output_handler['response']) From b928ab3daaa4829f9c25e17268cc6655e917a2de Mon Sep 17 00:00:00 2001 From: Ankit Saini Date: Thu, 11 Jun 2026 10:49:26 +0530 Subject: [PATCH 15/19] pin ip --- singlestoredb/ai/embeddings.py | 73 +++++++++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/singlestoredb/ai/embeddings.py b/singlestoredb/ai/embeddings.py index 5a023adc5..8529b21ae 100644 --- a/singlestoredb/ai/embeddings.py +++ b/singlestoredb/ai/embeddings.py @@ -33,6 +33,21 @@ _http_log = logging.getLogger('singlestoredb.ai.embeddings.http') +def _fmt_addr(addr: Any) -> str: + """Format a ``(host, port, ...)`` sockaddr tuple as ``host:port``.""" + if not addr: + return '?' + try: + return f'{addr[0]}:{addr[1]}' + except Exception: + return str(addr) + + +# Hosts already logged as transport-pinned, so we only emit one line per host +# instead of one per request. +_pin_logged_hosts: 'set[str]' = set() + + class HttpTraceIdFilter(logging.Filter): """ Stamps every log record with the current :data:`http_trace_id` value @@ -111,6 +126,43 @@ async def handle_async_request( ) -> httpx.Response: tid = http_trace_id.get() rid = uuid.uuid4().hex[:6] + + # Deterministic, resolver-independent IP pin. A DNS pin (monkeypatching + # socket.getaddrinfo) is bypassed by httpx/anyio's resolution path, so + # connections still spread across every NLB IP and hit the cross-AZ + # source-port collision. Here we instead dial a fixed IP at the + # transport while preserving TLS SNI + certificate validation + the + # Host header, so every embedding request lands on one endpoint + # regardless of how DNS is resolved. Enabled via + # SINGLESTOREDB_EMBEDDINGS_PIN_IP (optionally restricted to + # SINGLESTOREDB_EMBEDDINGS_PIN_HOST). Read per-request so a pin set + # after this client was constructed (e.g. by the EMBED_TEXT notebook) + # still takes effect. + pin_ip = os.environ.get('SINGLESTOREDB_EMBEDDINGS_PIN_IP') + if pin_ip: + pin_host = os.environ.get('SINGLESTOREDB_EMBEDDINGS_PIN_HOST') + host = request.url.host + if host and host != pin_ip and (not pin_host or host == pin_host): + try: + # Preserve the original Host header (httpx set it to the + # hostname at build time); only ensure it is present. + if 'host' not in request.headers: + request.headers['Host'] = host + request.extensions = { + **request.extensions, 'sni_hostname': host, + } + request.url = request.url.copy_with(host=pin_ip) + if host not in _pin_logged_hosts: + _pin_logged_hosts.add(host) + _http_log.warning( + '[%s/%s] TRANSPORT IP PIN active: dialing %s via %s ' + '(SNI/cert/Host preserved)', tid, rid, host, pin_ip, + ) + except Exception as e: + _http_log.warning( + '[%s/%s] TRANSPORT IP PIN failed: %r', tid, rid, e, + ) + try: body_len = len(request.content or b'') except httpx.RequestNotRead: @@ -130,9 +182,28 @@ async def handle_async_request( raise t_headers = time.perf_counter() + # Surface the actual local/remote socket addresses for this request so + # the (src_ip:src_port -> dst_ip:dst_port) 4-tuple can be correlated + # with node-side tcpdump/conntrack captures. ``src`` here is the pod's + # chosen ephemeral port (the one Cilium masquerade normally preserves), + # and ``dst`` is the resolved endpoint IP actually connected to. + local_addr = remote_addr = None + try: + stream = response.extensions.get('network_stream') + if stream is not None: + local_addr = stream.get_extra_info('client_addr') + remote_addr = stream.get_extra_info('server_addr') + if local_addr is None or remote_addr is None: + sock = stream.get_extra_info('socket') + if sock is not None: + local_addr = local_addr or sock.getsockname() + remote_addr = remote_addr or sock.getpeername() + except Exception: + pass _http_log.info( - '[%s/%s] <- headers status=%d ttfb=%.3fs', + '[%s/%s] <- headers status=%d ttfb=%.3fs src=%s dst=%s', tid, rid, response.status_code, t_headers - t0, + _fmt_addr(local_addr), _fmt_addr(remote_addr), ) # Wrap the body stream so we also time how long the body download From 6933ff344472d6430a43ff0844284f82f23e6f02 Mon Sep 17 00:00:00 2001 From: Ankit Saini Date: Mon, 15 Jun 2026 18:10:39 +0530 Subject: [PATCH 16/19] Revert "pin ip" This reverts commit b928ab3daaa4829f9c25e17268cc6655e917a2de. --- singlestoredb/ai/embeddings.py | 73 +--------------------------------- 1 file changed, 1 insertion(+), 72 deletions(-) diff --git a/singlestoredb/ai/embeddings.py b/singlestoredb/ai/embeddings.py index 8529b21ae..5a023adc5 100644 --- a/singlestoredb/ai/embeddings.py +++ b/singlestoredb/ai/embeddings.py @@ -33,21 +33,6 @@ _http_log = logging.getLogger('singlestoredb.ai.embeddings.http') -def _fmt_addr(addr: Any) -> str: - """Format a ``(host, port, ...)`` sockaddr tuple as ``host:port``.""" - if not addr: - return '?' - try: - return f'{addr[0]}:{addr[1]}' - except Exception: - return str(addr) - - -# Hosts already logged as transport-pinned, so we only emit one line per host -# instead of one per request. -_pin_logged_hosts: 'set[str]' = set() - - class HttpTraceIdFilter(logging.Filter): """ Stamps every log record with the current :data:`http_trace_id` value @@ -126,43 +111,6 @@ async def handle_async_request( ) -> httpx.Response: tid = http_trace_id.get() rid = uuid.uuid4().hex[:6] - - # Deterministic, resolver-independent IP pin. A DNS pin (monkeypatching - # socket.getaddrinfo) is bypassed by httpx/anyio's resolution path, so - # connections still spread across every NLB IP and hit the cross-AZ - # source-port collision. Here we instead dial a fixed IP at the - # transport while preserving TLS SNI + certificate validation + the - # Host header, so every embedding request lands on one endpoint - # regardless of how DNS is resolved. Enabled via - # SINGLESTOREDB_EMBEDDINGS_PIN_IP (optionally restricted to - # SINGLESTOREDB_EMBEDDINGS_PIN_HOST). Read per-request so a pin set - # after this client was constructed (e.g. by the EMBED_TEXT notebook) - # still takes effect. - pin_ip = os.environ.get('SINGLESTOREDB_EMBEDDINGS_PIN_IP') - if pin_ip: - pin_host = os.environ.get('SINGLESTOREDB_EMBEDDINGS_PIN_HOST') - host = request.url.host - if host and host != pin_ip and (not pin_host or host == pin_host): - try: - # Preserve the original Host header (httpx set it to the - # hostname at build time); only ensure it is present. - if 'host' not in request.headers: - request.headers['Host'] = host - request.extensions = { - **request.extensions, 'sni_hostname': host, - } - request.url = request.url.copy_with(host=pin_ip) - if host not in _pin_logged_hosts: - _pin_logged_hosts.add(host) - _http_log.warning( - '[%s/%s] TRANSPORT IP PIN active: dialing %s via %s ' - '(SNI/cert/Host preserved)', tid, rid, host, pin_ip, - ) - except Exception as e: - _http_log.warning( - '[%s/%s] TRANSPORT IP PIN failed: %r', tid, rid, e, - ) - try: body_len = len(request.content or b'') except httpx.RequestNotRead: @@ -182,28 +130,9 @@ async def handle_async_request( raise t_headers = time.perf_counter() - # Surface the actual local/remote socket addresses for this request so - # the (src_ip:src_port -> dst_ip:dst_port) 4-tuple can be correlated - # with node-side tcpdump/conntrack captures. ``src`` here is the pod's - # chosen ephemeral port (the one Cilium masquerade normally preserves), - # and ``dst`` is the resolved endpoint IP actually connected to. - local_addr = remote_addr = None - try: - stream = response.extensions.get('network_stream') - if stream is not None: - local_addr = stream.get_extra_info('client_addr') - remote_addr = stream.get_extra_info('server_addr') - if local_addr is None or remote_addr is None: - sock = stream.get_extra_info('socket') - if sock is not None: - local_addr = local_addr or sock.getsockname() - remote_addr = remote_addr or sock.getpeername() - except Exception: - pass _http_log.info( - '[%s/%s] <- headers status=%d ttfb=%.3fs src=%s dst=%s', + '[%s/%s] <- headers status=%d ttfb=%.3fs', tid, rid, response.status_code, t_headers - t0, - _fmt_addr(local_addr), _fmt_addr(remote_addr), ) # Wrap the body stream so we also time how long the body download From 3e4695e83c0acbb16d8e66e3dd1b60670ed6c588 Mon Sep 17 00:00:00 2001 From: Ankit Saini Date: Mon, 15 Jun 2026 20:47:04 +0530 Subject: [PATCH 17/19] rewrite --- pyproject.toml | 2 +- singlestoredb/ai/embeddings.py | 163 --------------------- singlestoredb/functions/ext/asgi.py | 157 ++++---------------- singlestoredb/tests/test_udf_event_loop.py | 155 +------------------- 4 files changed, 38 insertions(+), 439 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0f57263e3..bb436d8ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "singlestoredb" -version = "1.16.11rc1+byasaini" +version = "1.16.10" description = "Interface to the SingleStoreDB database and workspace management APIs" readme = {file = "README.md", content-type = "text/markdown"} license = {text = "Apache-2.0"} diff --git a/singlestoredb/ai/embeddings.py b/singlestoredb/ai/embeddings.py index 5a023adc5..68dfc9d96 100644 --- a/singlestoredb/ai/embeddings.py +++ b/singlestoredb/ai/embeddings.py @@ -1,10 +1,5 @@ -import contextvars -import logging import os -import time -import uuid from typing import Any -from typing import AsyncIterator from typing import Callable from typing import Optional from typing import Union @@ -15,159 +10,6 @@ from singlestoredb.management.inference_api import InferenceAPIInfo -# Per-task trace id propagated into the embeddings HTTP transport. -# -# Callers (e.g. an EMBED_TEXT UDF) can do ``http_trace_id.set("")`` -# right before invoking ``aembed_documents``. The :class:`TracingAsyncTransport` -# reads this var inside ``handle_async_request`` and stamps every log line -# with it, so each per-stage HTTP timing line can be correlated back to the -# UDF request that produced it. -# -# ContextVars are per-asyncio-Task: even with many concurrent EMBED_TEXT -# coroutines sharing a single dispatch loop, each call has its own private -# context, so a set() in one task does not leak into another. -http_trace_id: 'contextvars.ContextVar[str]' = contextvars.ContextVar( - 'singlestoredb_embeddings_http_trace_id', default='-', -) - -_http_log = logging.getLogger('singlestoredb.ai.embeddings.http') - - -class HttpTraceIdFilter(logging.Filter): - """ - Stamps every log record with the current :data:`http_trace_id` value - under ``record.trace_id``. - - Attach this to handlers (or loggers) for ``httpx`` / ``httpcore`` so - that their low-level per-stage log lines (``connect_tcp.started``, - ``start_tls.started``, ``send_request_body.complete`` etc.) carry the - same trace id the caller stamped before invoking ``aembed_documents``. - - Why this works: ``httpcore`` runs inside the same ``asyncio.Task`` as - the caller (it's just deeper in the ``await`` chain), and ContextVars - are per-Task, so ``http_trace_id.get()`` inside ``filter()`` returns - the value set by the caller for that specific request. Each concurrent - EMBED_TEXT call has its own value and they do not bleed across. - """ - - def filter(self, record: logging.LogRecord) -> bool: - record.trace_id = http_trace_id.get() - return True - - -def enable_http_debug_logging(level: int = logging.DEBUG) -> None: - """ - Turn on per-stage httpx / httpcore logging. - - This is very verbose (one log line per TCP connect, TLS handshake, - header send, body send, response header read, response body read, etc.) - but is the fastest way to pinpoint which network phase a stuck embedding - request is sitting in. Enable on demand, e.g. by setting the - ``SINGLESTOREDB_EMBEDDINGS_HTTP_DEBUG=1`` env var or by calling this - function directly. - """ - for name in ( - 'httpx', - 'httpcore', - 'httpcore.connection', - 'httpcore.http11', - 'httpcore.http2', - 'httpcore.proxy', - ): - logging.getLogger(name).setLevel(level) - - -if os.environ.get( - 'SINGLESTOREDB_EMBEDDINGS_HTTP_DEBUG', '', -).lower() in ('1', 'true', 'yes'): - enable_http_debug_logging() - - -class TracingAsyncTransport(httpx.AsyncBaseTransport): - """ - Wraps another :class:`httpx.AsyncBaseTransport` and logs per-stage - timings for every request, stamped with the current :data:`http_trace_id` - context value. - - Emits three log lines per request: - - 1. ``->`` when the request is handed to the inner transport, with the - outgoing body size. - 2. ``<- headers`` when the response headers arrive (gives time-to-first- - byte, i.e. the gap that captures DNS + TCP + TLS + upload + upstream - processing). - 3. ``<- body`` when the response body is fully consumed (gives separate - body-download elapsed and total elapsed). - - The 1->2 gap vs the 2->3 gap is what tells you whether a hang is on the - request/upstream side or on the response/download side. - """ - - def __init__(self, inner: httpx.AsyncBaseTransport) -> None: - self._inner = inner - - async def handle_async_request( - self, request: httpx.Request, - ) -> httpx.Response: - tid = http_trace_id.get() - rid = uuid.uuid4().hex[:6] - try: - body_len = len(request.content or b'') - except httpx.RequestNotRead: - body_len = -1 - t0 = time.perf_counter() - _http_log.info( - '[%s/%s] -> %s %s body=%dB', - tid, rid, request.method, request.url, body_len, - ) - try: - response = await self._inner.handle_async_request(request) - except BaseException as e: - _http_log.error( - '[%s/%s] xx EXC after %.3fs: %r', - tid, rid, time.perf_counter() - t0, e, - ) - raise - - t_headers = time.perf_counter() - _http_log.info( - '[%s/%s] <- headers status=%d ttfb=%.3fs', - tid, rid, response.status_code, t_headers - t0, - ) - - # Wrap the body stream so we also time how long the body download - # itself takes. Captures `tid`, `rid`, `t0`, `t_headers` by closure - # so the log line is correctly correlated even though body consumption - # happens later (after handle_async_request has returned). - original_stream = response.stream - - class _TimingStream(httpx.AsyncByteStream): - - async def __aiter__(self) -> AsyncIterator[bytes]: - total = 0 - t_body_start = time.perf_counter() - try: - async for chunk in original_stream: - total += len(chunk) - yield chunk - finally: - t_body_end = time.perf_counter() - _http_log.info( - '[%s/%s] <- body bytes=%d body_elapsed=%.3fs ' - 'total=%.3fs', - tid, rid, total, - t_body_end - t_body_start, t_body_end - t0, - ) - - async def aclose(self) -> None: - await original_stream.aclose() - - response.stream = _TimingStream() - return response - - async def aclose(self) -> None: - await self._inner.aclose() - try: from langchain_openai import OpenAIEmbeddings except ImportError: @@ -368,11 +210,6 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: http_async_client = httpx.AsyncClient( timeout=client_timeout, limits=client_limits, - transport=TracingAsyncTransport( - httpx.AsyncHTTPTransport( - limits=client_limits, - ), - ), ) openai_kwargs['http_async_client'] = http_async_client diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index dafa7650d..a4ab6dac7 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -136,90 +136,18 @@ async def _cancellable_run( return task.result() -# Each `to_thread` worker thread owns a long-lived event loop reused across -# requests, so loop-bound resources (HTTP pools, DB sessions, sockets) can -# survive between calls handled by the same thread. -# -# This per-thread loop is only used for SYNC user UDFs: a sync UDF blocks -# its worker thread for the duration of the call, so giving each worker -# thread its own loop avoids cross-thread loop sharing for those calls. -_thread_local = threading.local() -_loop_registry: 'Set[asyncio.AbstractEventLoop]' = set() -_loop_registry_lock = threading.Lock() - - -def _get_thread_loop() -> asyncio.AbstractEventLoop: - """Return (creating if needed) the calling thread's persistent loop.""" - loop = getattr(_thread_local, 'loop', None) - if loop is None or loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - _thread_local.loop = loop - with _loop_registry_lock: - _loop_registry.add(loop) - return loop - - -def _run_on_thread_loop(coro: Any) -> Any: - """ - Run ``coro`` on the calling thread's persistent loop. - - The loop is never closed between calls, so loop-bound resources (e.g. - httpx keep-alive pools) survive across requests and the deferred - "Event loop is closed" errors thrown by httpx/anyio at teardown do not - occur. - - Caveat: tasks the user code spawns via ``asyncio.create_task`` and - leaves running outlive the current call too. That is the price of - keeping shared resources alive; ``cancel_event`` does not reach them. - """ - loop = _get_thread_loop() - return loop.run_until_complete(coro) - - -def _shutdown_thread_loops() -> None: - """Best-effort cleanup of all persistent worker-thread loops at exit.""" - with _loop_registry_lock: - loops = list(_loop_registry) - _loop_registry.clear() - - for loop in loops: - if loop.is_closed(): - continue - try: - # Owning thread is no longer running the loop; safe to drive - # teardown from this (exiting) thread. - loop.run_until_complete(loop.shutdown_asyncgens()) - loop.run_until_complete(loop.shutdown_default_executor()) - except Exception: - pass - finally: - try: - loop.close() - except Exception: - pass - - -atexit.register(_shutdown_thread_loops) - - # Dedicated event loop used for ALL async UDF requests. # -# Async UDFs commonly create resources that are bound to the event loop -# they are first used on (httpx connection pools, async DB clients, anyio -# streams, ...). Dispatching async requests to ad-hoc worker threads — -# each with its own loop — produced " is bound -# to a different event loop" errors when those cached resources were -# reused by a request that landed on a different worker thread. -# -# Routing every async UDF onto a single dedicated loop fixes that, and -# also gives true concurrency across requests: ``run_coroutine_threadsafe`` -# schedules each new coroutine immediately so that incoming requests do -# not queue behind in-flight ones. +# Async UDFs commonly create resources bound to the event loop they are +# first used on (httpx pools, async DB clients, anyio streams, ...). Routing +# every async UDF onto one dedicated loop lets those resources be reused +# safely across requests and avoids the "bound to a different event loop" +# errors seen when requests land on different ad-hoc worker threads. +# ``run_coroutine_threadsafe`` schedules each coroutine immediately, so +# requests run concurrently rather than queuing behind in-flight ones. # -# Sync UDFs intentionally still go through the worker-thread / per-thread -# loop path above: a sync UDF would block this dedicated loop and starve -# every other in-flight async request. +# Sync UDFs instead run in a worker thread (one ``asyncio.run`` per call): +# a sync UDF would block this shared loop and starve other async requests. _async_dispatch_loop: 'Optional[asyncio.AbstractEventLoop]' = None _async_dispatch_thread: 'Optional[threading.Thread]' = None _async_dispatch_lock = threading.Lock() @@ -229,10 +157,8 @@ def _get_async_dispatch_loop() -> asyncio.AbstractEventLoop: """ Return (lazily creating) the singleton async-dispatch event loop. - The loop is owned by a dedicated daemon thread that runs ``run_forever`` - for the lifetime of the process. All async UDF coroutines are scheduled - on this loop so that loop-bound resources can be safely reused across - requests. + Owned by a dedicated daemon thread running ``run_forever`` for the life + of the process (see the module-level notes above for the rationale). """ global _async_dispatch_loop, _async_dispatch_thread @@ -291,17 +217,11 @@ def _get_async_dispatch_thread() -> 'Optional[threading.Thread]': async def _dispatch_to_async_loop(coro: Any) -> Any: """ - Schedule ``coro`` on the dedicated async-dispatch loop and await its result. + Schedule ``coro`` on the dedicated async-dispatch loop and await it. - The coroutine begins running immediately on the dispatch loop — it does - NOT wait for any earlier in-flight async UDF to complete — so concurrent - async requests run in parallel on a single shared loop. - - Cancellation of the awaiting task on the caller's loop is propagated - best-effort to the work scheduled on the dispatch loop. The user code's - ``cancel_event`` (set by the request handler on timeout / disconnect) - remains the authoritative cancellation signal because it is observed by - ``_cancellable_run`` from inside the dispatch loop. + Cancelling the awaiting task best-effort cancels the scheduled work, but + ``cancel_event`` (observed by ``_cancellable_run`` from inside the + dispatch loop) remains the authoritative cancellation signal. """ loop = _get_async_dispatch_loop() cf = asyncio.run_coroutine_threadsafe(coro, loop) @@ -1411,24 +1331,14 @@ async def __call__( cancel_event = threading.Event() - # Parsing the request body can be CPU heavy (esp. for - # rowdat_1 / arrow payloads). Run it in the default - # executor thread pool so the main uvicorn loop is not - # blocked while inputs are being decoded. - load_input = input_handler['load'] # type: ignore - colspec = func_info['colspec'] - async with timer('parse_input'): - inputs = await to_thread( - lambda: load_input(colspec, b''.join(data)), + with timer('parse_input'): + inputs = input_handler['load']( # type: ignore + func_info['colspec'], b''.join(data), ) - # Async user UDFs share a single dedicated event-loop thread - # so that loop-bound resources (httpx pools, async clients, - # ...) can be reused across requests; new requests are - # scheduled immediately and run concurrently on that loop. - # Sync user UDFs continue to use the worker-thread pool (one - # persistent loop per thread) because a sync call would - # block the shared dispatch loop and starve other requests. + # Async UDFs run on the dedicated dispatch loop; sync UDFs run + # in a worker thread (one asyncio.run per call) so they cannot + # block that shared loop (see the module-level notes above). if func_info.get('is_async'): func_task = asyncio.create_task( _dispatch_to_async_loop( @@ -1441,11 +1351,8 @@ async def __call__( else: func_task = asyncio.create_task( to_thread( - lambda: _run_on_thread_loop( - _cancellable_run( - cancel_event, - func(cancel_event, call_timer, *inputs), - ), + lambda: asyncio.run( + func(cancel_event, call_timer, *inputs), ), ), ) @@ -1464,9 +1371,9 @@ async def __call__( all_tasks, return_when=asyncio.FIRST_COMPLETED, ) - # Signal the worker before awaiting cancellation: cancelling - # func_task only flips its asyncio wrapper, not the executor - # work; only cancel_event reaches the worker loop. + # Signal cancellation before awaiting: cancelling func_task + # only unwinds its asyncio wrapper on this loop, not the work + # running off-thread; cancel_event is what actually reaches it. if func_task in pending: cancel_event.set() @@ -1486,15 +1393,9 @@ async def __call__( elif task is func_task: result.extend(task.result()) - # Serializing the response can also be CPU heavy. Run it - # in the default executor thread pool so the main - # uvicorn loop stays responsive to other connections - # while this request is being encoded. - dump_output = output_handler['dump'] # type: ignore - return_types = [x[1] for x in func_info['returns']] - async with timer('format_output'): - body = await to_thread( - dump_output, return_types, *result, + with timer('format_output'): + body = output_handler['dump']( + [x[1] for x in func_info['returns']], *result, # type: ignore ) await send(output_handler['response']) diff --git a/singlestoredb/tests/test_udf_event_loop.py b/singlestoredb/tests/test_udf_event_loop.py index 7c89c397a..365b3d426 100644 --- a/singlestoredb/tests/test_udf_event_loop.py +++ b/singlestoredb/tests/test_udf_event_loop.py @@ -1,4 +1,4 @@ -"""Tests for the async UDF persistent per-thread event loop.""" +"""Tests for the dedicated async UDF dispatch event loop.""" import asyncio import contextvars import json as jsonlib @@ -16,8 +16,6 @@ from ..functions.ext.asgi import _dispatch_to_async_loop from ..functions.ext.asgi import _get_async_dispatch_loop from ..functions.ext.asgi import _get_async_dispatch_thread -from ..functions.ext.asgi import _get_thread_loop -from ..functions.ext.asgi import _run_on_thread_loop from ..functions.ext.asgi import Application from ..functions.ext.asgi import to_thread @@ -42,7 +40,7 @@ def set_cancel_after_delay() -> None: start = time.monotonic() with self.assertRaises(asyncio.CancelledError): - _run_on_thread_loop( + asyncio.run( _cancellable_run(cancel_event, long_running()), ) elapsed = time.monotonic() - start @@ -61,7 +59,7 @@ async def failing_udf() -> None: raise CustomUDFError('embedding service unavailable') with self.assertRaises(CustomUDFError) as ctx: - _run_on_thread_loop( + asyncio.run( _cancellable_run(cancel_event, failing_udf()), ) self.assertEqual(str(ctx.exception), 'embedding service unavailable') @@ -83,7 +81,7 @@ def set_cancel() -> None: start = time.monotonic() with self.assertRaises(asyncio.CancelledError): - _run_on_thread_loop( + asyncio.run( _cancellable_run(cancel_event, blocked()), ) elapsed = time.monotonic() - start @@ -107,7 +105,7 @@ def read_context_var() -> str: async def run_in_thread() -> str: return await to_thread(read_context_var) - result = _run_on_thread_loop(run_in_thread()) + result = asyncio.run(run_in_thread()) self.assertEqual(result, 'hello_from_parent') self.assertEqual(captured, ['hello_from_parent']) @@ -120,7 +118,7 @@ async def compute() -> int: await asyncio.sleep(0.05) return index * 10 - results[index] = _run_on_thread_loop(compute()) + results[index] = asyncio.run(compute()) threads = [ threading.Thread(target=run_isolated, args=(i,)) @@ -141,7 +139,7 @@ async def sync_as_async() -> int: # Simulates what decorator.py's async_wrapper does for sync UDFs return 42 + 1 - result = _run_on_thread_loop( + result = asyncio.run( _cancellable_run(cancel_event, sync_as_async()), ) self.assertEqual(result, 43) @@ -153,150 +151,13 @@ def test_cancel_event_not_set_on_success(self) -> None: async def quick() -> str: return 'fast' - result = _run_on_thread_loop( + result = asyncio.run( _cancellable_run(cancel_event, quick()), ) self.assertEqual(result, 'fast') self.assertFalse(cancel_event.is_set()) -class TestRunOnThreadLoop(unittest.TestCase): - """Test _run_on_thread_loop reuses a persistent per-thread event loop.""" - - def test_basic_coroutine(self) -> None: - async def simple() -> int: - return 42 - - self.assertEqual(_run_on_thread_loop(simple()), 42) - - def test_loop_reused_across_calls(self) -> None: - """The same loop object is reused for successive calls in a thread.""" - loops: List[asyncio.AbstractEventLoop] = [] - - async def capture_loop() -> bool: - loops.append(asyncio.get_running_loop()) - return True - - _run_on_thread_loop(capture_loop()) - _run_on_thread_loop(capture_loop()) - - self.assertIs(loops[0], loops[1]) - - def test_loop_not_closed_between_calls(self) -> None: - """The persistent loop stays open so resources survive requests.""" - captured: List[asyncio.AbstractEventLoop] = [] - - async def capture_loop() -> bool: - captured.append(asyncio.get_running_loop()) - return True - - _run_on_thread_loop(capture_loop()) - loop = captured[0] - self.assertFalse(loop.is_closed()) - - # Still usable for the next request. - _run_on_thread_loop(capture_loop()) - self.assertFalse(loop.is_closed()) - - def test_async_resource_survives_between_calls(self) -> None: - """An object bound to the loop can be reused on the next call. - - This mirrors caching e.g. an httpx.AsyncClient keyed by the loop and - reusing its connection pool on subsequent requests. - """ - clients: dict = {} - - async def get_or_create_client() -> int: - loop = asyncio.get_running_loop() - if loop not in clients: - clients[loop] = object() - return id(clients[loop]) - - first = _run_on_thread_loop(get_or_create_client()) - second = _run_on_thread_loop(get_or_create_client()) - - self.assertEqual(first, second) - self.assertEqual(len(clients), 1) - - def test_separate_threads_get_separate_loops(self) -> None: - """Each worker thread owns its own persistent loop.""" - loops: List[asyncio.AbstractEventLoop] = [] - lock = threading.Lock() - - def run_in_thread() -> None: - async def capture() -> bool: - with lock: - loops.append(asyncio.get_running_loop()) - return True - - _run_on_thread_loop(capture()) - - threads = [threading.Thread(target=run_in_thread) for _ in range(3)] - for t in threads: - t.start() - for t in threads: - t.join() - - self.assertEqual(len(loops), 3) - self.assertEqual(len({id(loop) for loop in loops}), 3) - - def test_get_thread_loop_idempotent(self) -> None: - """_get_thread_loop returns the same loop on repeated calls.""" - def run_in_thread(out: List[asyncio.AbstractEventLoop]) -> None: - out.append(_get_thread_loop()) - out.append(_get_thread_loop()) - - out: List[asyncio.AbstractEventLoop] = [] - t = threading.Thread(target=run_in_thread, args=(out,)) - t.start() - t.join() - - self.assertIs(out[0], out[1]) - - def test_exception_propagates(self) -> None: - async def failing() -> None: - raise ValueError('test error') - - with self.assertRaises(ValueError) as ctx: - _run_on_thread_loop(failing()) - self.assertEqual(str(ctx.exception), 'test error') - - def test_cancellable_run_integration(self) -> None: - """_cancellable_run works on the persistent loop.""" - cancel_event = threading.Event() - - async def slow_func() -> str: - return 'completed' - - result = _run_on_thread_loop( - _cancellable_run(cancel_event, slow_func()), - ) - self.assertEqual(result, 'completed') - - def test_cancellation_via_event(self) -> None: - """Cancellation propagates through the persistent-loop stack.""" - cancel_event = threading.Event() - cancel_event.set() - - async def blocked_func() -> str: - await asyncio.sleep(999) - return 'should not reach' - - with self.assertRaises(asyncio.CancelledError): - _run_on_thread_loop( - _cancellable_run(cancel_event, blocked_func()), - ) - - # Loop must remain usable after a cancelled request. - async def quick() -> str: - return 'ok' - - self.assertEqual( - _run_on_thread_loop(_cancellable_run(threading.Event(), quick())), - 'ok', - ) - - class TestAsyncDispatchLoop(unittest.TestCase): """All async UDF dispatches share a single dedicated event-loop thread. From 0a7a0a9bbbf0c6ae2288788eb9eab0bfed4580cd Mon Sep 17 00:00:00 2001 From: Ankit Saini Date: Mon, 15 Jun 2026 21:26:04 +0530 Subject: [PATCH 18/19] rewrite 2 --- singlestoredb/ai/embeddings.py | 61 --- singlestoredb/apps/_python_udfs.py | 5 +- singlestoredb/functions/ext/asgi.py | 18 + singlestoredb/tests/test_udf_event_loop.py | 461 ++++----------------- 4 files changed, 93 insertions(+), 452 deletions(-) diff --git a/singlestoredb/ai/embeddings.py b/singlestoredb/ai/embeddings.py index 68dfc9d96..aba8e1c47 100644 --- a/singlestoredb/ai/embeddings.py +++ b/singlestoredb/ai/embeddings.py @@ -9,7 +9,6 @@ from singlestoredb import manage_workspaces from singlestoredb.management.inference_api import InferenceAPIInfo - try: from langchain_openai import OpenAIEmbeddings except ImportError: @@ -35,7 +34,6 @@ def SingleStoreEmbeddingsFactory( model_name: str, api_key: Optional[str] = None, http_client: Optional[httpx.Client] = None, - http_async_client: Optional[httpx.AsyncClient] = None, obo_token_getter: Optional[Callable[[], Optional[str]]] = None, base_url: Optional[str] = None, hosting_platform: Optional[str] = None, @@ -154,65 +152,6 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: ) if http_client is not None: openai_kwargs['http_client'] = http_client - - if http_async_client is None: - # Explicit timeouts: without these, httpx falls back to its 5s - # default at the client level, but the OpenAI SDK overrides that - # with a per-request 600s read timeout, so a stalled response can - # sit on the socket for ~10 minutes before httpx notices. We use a - # tighter read timeout so a dead/half-open connection fails fast - # instead of waiting for the application-level defensive timeout - # (e.g. EMBED_TEXT's asyncio.wait_for) to fire. - client_timeout = httpx.Timeout( - connect=float( - os.environ.get( - 'SINGLESTOREDB_EMBEDDINGS_CONNECT_TIMEOUT', '10', - ), - ), - read=float( - os.environ.get( - 'SINGLESTOREDB_EMBEDDINGS_READ_TIMEOUT', '60', - ), - ), - write=float( - os.environ.get( - 'SINGLESTOREDB_EMBEDDINGS_WRITE_TIMEOUT', '30', - ), - ), - pool=float( - os.environ.get( - 'SINGLESTOREDB_EMBEDDINGS_POOL_TIMEOUT', '10', - ), - ), - ) - # Allow connection reuse. The previous configuration - # (max_keepalive_connections=0) forced a fresh TCP+TLS handshake - # for every request, which under heavy concurrency churns sockets - # and occasionally yields one connection that the upstream accepts - # but never finishes responding on. - client_limits = httpx.Limits( - max_connections=int( - os.environ.get( - 'SINGLESTOREDB_EMBEDDINGS_MAX_CONNECTIONS', '64', - ), - ), - max_keepalive_connections=int( - os.environ.get( - 'SINGLESTOREDB_EMBEDDINGS_MAX_KEEPALIVE', '16', - ), - ), - keepalive_expiry=float( - os.environ.get( - 'SINGLESTOREDB_EMBEDDINGS_KEEPALIVE_EXPIRY', '30', - ), - ), - ) - http_async_client = httpx.AsyncClient( - timeout=client_timeout, - limits=client_limits, - ) - openai_kwargs['http_async_client'] = http_async_client - return OpenAIEmbeddings( **openai_kwargs, **kwargs, diff --git a/singlestoredb/apps/_python_udfs.py b/singlestoredb/apps/_python_udfs.py index 5c8cf4a73..267a12c63 100644 --- a/singlestoredb/apps/_python_udfs.py +++ b/singlestoredb/apps/_python_udfs.py @@ -61,8 +61,9 @@ async def run_udf_app( f'You can only define a maximum of {MAX_UDFS_LIMIT} functions.', ) - # Increase the timeout so the uvicorn server is not the one closing idle connections. - # Avoiding TIME_WAIT state, rendering the client_port unusable for 60s (default TIME_WAIT duration). + # Raise the keep-alive timeout so uvicorn does not actively close idle + # connections aggressively. Whichever side closes first holds the socket in TIME_WAIT + # (~60s on Linux), so server-initiated closes churn sockets under load. keep_alive_timeout = int( os.environ.get('SINGLESTOREDB_UDF_KEEPALIVE_TIMEOUT', '120'), ) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index a4ab6dac7..083ea5d03 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -115,6 +115,14 @@ async def to_thread( async def _poll_cancel(cancel_event: threading.Event) -> None: + """ + Return once ``cancel_event`` is set, polling it on the running loop. + + ``threading.Event`` has no awaitable interface, so this bridges the + cross-thread cancellation signal into the dispatch loop by polling on a + short interval. Used as a sibling task to the UDF coroutine in + ``_cancellable_run``. + """ while not cancel_event.is_set(): await asyncio.sleep(0.1) @@ -123,6 +131,16 @@ async def _cancellable_run( cancel_event: threading.Event, coro: Any, ) -> Any: + """ + Run ``coro`` but abandon it if ``cancel_event`` is tripped. + + The coroutine races ``_poll_cancel``; whichever finishes first wins. If + the cancel signal wins, the coroutine's task is cancelled and + ``CancelledError`` is raised, otherwise its result (or exception) is + propagated. This is the authoritative cancellation path for async UDFs: + they run on the shared dispatch loop, where ordinary task cancellation + from the request loop does not reach them. + """ task = asyncio.create_task(coro) cancel_check = asyncio.create_task(_poll_cancel(cancel_event)) done, pending = await asyncio.wait( diff --git a/singlestoredb/tests/test_udf_event_loop.py b/singlestoredb/tests/test_udf_event_loop.py index 365b3d426..e5b772684 100644 --- a/singlestoredb/tests/test_udf_event_loop.py +++ b/singlestoredb/tests/test_udf_event_loop.py @@ -20,36 +20,35 @@ from ..functions.ext.asgi import to_thread -class TestUDFDispatchEdgeCases(unittest.TestCase): - """Test edge cases in the UDF dispatch stack.""" +class TestCancellableRun(unittest.TestCase): + """Unit tests for ``_cancellable_run`` and the ``to_thread`` helper.""" - def test_timeout_cancels_running_function(self) -> None: - """Cancel event set from timer thread cancels a blocked coroutine.""" + def test_cancel_event_cancels_blocked_coroutine(self) -> None: + """Tripping ``cancel_event`` interrupts a coroutine blocked on I/O. + + The coroutine sleeps far longer than the test could tolerate, so the + test only completes if the cancel signal actually unblocks it. + """ cancel_event = threading.Event() - async def long_running() -> str: + async def blocked() -> str: await asyncio.sleep(999) - return 'should not reach' + return 'unreachable' - def set_cancel_after_delay() -> None: - time.sleep(0.2) + def trip_cancel_soon() -> None: + time.sleep(0.1) cancel_event.set() - timer = threading.Thread(target=set_cancel_after_delay) + timer = threading.Thread(target=trip_cancel_soon) timer.start() - - start = time.monotonic() - with self.assertRaises(asyncio.CancelledError): - asyncio.run( - _cancellable_run(cancel_event, long_running()), - ) - elapsed = time.monotonic() - start - timer.join() - # 0.2s delay + up to 0.1s poll interval + margin - self.assertLess(elapsed, 0.5) - - def test_exception_propagates_through_full_stack(self) -> None: - """User exception propagates unwrapped through the entire dispatch.""" + try: + with self.assertRaises(asyncio.CancelledError): + asyncio.run(_cancellable_run(cancel_event, blocked())) + finally: + timer.join() + + def test_exception_propagates_unwrapped(self) -> None: + """A user exception surfaces unchanged through ``_cancellable_run``.""" cancel_event = threading.Event() class CustomUDFError(Exception): @@ -59,103 +58,34 @@ async def failing_udf() -> None: raise CustomUDFError('embedding service unavailable') with self.assertRaises(CustomUDFError) as ctx: - asyncio.run( - _cancellable_run(cancel_event, failing_udf()), - ) + asyncio.run(_cancellable_run(cancel_event, failing_udf())) self.assertEqual(str(ctx.exception), 'embedding service unavailable') - def test_cancel_event_detected_within_poll_interval(self) -> None: - """Cancellation is detected within one poll cycle (0.1s).""" + def test_successful_run_returns_result_and_leaves_event_unset(self) -> None: + """A successful run returns its value without tripping the event.""" cancel_event = threading.Event() - async def blocked() -> str: - await asyncio.sleep(999) - return 'unreachable' - - def set_cancel() -> None: - time.sleep(0.05) - cancel_event.set() - - timer = threading.Thread(target=set_cancel) - timer.start() + async def quick() -> int: + return 42 + 1 - start = time.monotonic() - with self.assertRaises(asyncio.CancelledError): - asyncio.run( - _cancellable_run(cancel_event, blocked()), - ) - elapsed = time.monotonic() - start - timer.join() - # 0.05s delay + 0.1s poll interval + margin - self.assertLess(elapsed, 0.25) + result = asyncio.run(_cancellable_run(cancel_event, quick())) + self.assertEqual(result, 43) + self.assertFalse(cancel_event.is_set()) def test_context_vars_propagate_through_to_thread(self) -> None: - """Context variables are visible inside to_thread executor.""" + """Context variables are visible inside the ``to_thread`` executor.""" test_var: contextvars.ContextVar[str] = contextvars.ContextVar( 'test_var', ) test_var.set('hello_from_parent') - captured: List[str] = [] def read_context_var() -> str: - val = test_var.get('NOT_FOUND') - captured.append(val) - return val + return test_var.get('NOT_FOUND') async def run_in_thread() -> str: return await to_thread(read_context_var) - result = asyncio.run(run_in_thread()) - self.assertEqual(result, 'hello_from_parent') - self.assertEqual(captured, ['hello_from_parent']) - - def test_concurrent_requests_isolated(self) -> None: - """Parallel executions don't share state.""" - results: List[Any] = [None, None, None] - - def run_isolated(index: int) -> None: - async def compute() -> int: - await asyncio.sleep(0.05) - return index * 10 - - results[index] = asyncio.run(compute()) - - threads = [ - threading.Thread(target=run_isolated, args=(i,)) - for i in range(3) - ] - for t in threads: - t.start() - for t in threads: - t.join() - - self.assertEqual(results, [0, 10, 20]) - - def test_sync_function_through_async_wrapper(self) -> None: - """Synchronous function works when wrapped as async coroutine.""" - cancel_event = threading.Event() - - async def sync_as_async() -> int: - # Simulates what decorator.py's async_wrapper does for sync UDFs - return 42 + 1 - - result = asyncio.run( - _cancellable_run(cancel_event, sync_as_async()), - ) - self.assertEqual(result, 43) - - def test_cancel_event_not_set_on_success(self) -> None: - """Cancel event remains unset after successful execution.""" - cancel_event = threading.Event() - - async def quick() -> str: - return 'fast' - - result = asyncio.run( - _cancellable_run(cancel_event, quick()), - ) - self.assertEqual(result, 'fast') - self.assertFalse(cancel_event.is_set()) + self.assertEqual(asyncio.run(run_in_thread()), 'hello_from_parent') class TestAsyncDispatchLoop(unittest.TestCase): @@ -168,147 +98,54 @@ class TestAsyncDispatchLoop(unittest.TestCase): in-flight requests. """ - def test_dispatch_loop_is_single_dedicated_thread(self) -> None: - """All dispatches run on the same dedicated thread (not the caller).""" + def test_dispatch_uses_single_dedicated_thread_and_loop(self) -> None: + """Every dispatch runs on the one dedicated thread/loop, never the + caller's thread.""" seen_threads: Set[int] = set() + seen_loops: List[asyncio.AbstractEventLoop] = [] async def capture() -> int: seen_threads.add(threading.get_ident()) + seen_loops.append(asyncio.get_running_loop()) return 1 async def run_many() -> None: - await asyncio.gather(*[ - _dispatch_to_async_loop(capture()) for _ in range(8) - ]) + await asyncio.gather( + *[ + _dispatch_to_async_loop(capture()) for _ in range(8) + ], + ) caller_thread = threading.get_ident() asyncio.run(run_many()) + # One dedicated thread, distinct from the caller, and it is the + # singleton dispatch thread. self.assertEqual(len(seen_threads), 1) self.assertNotIn(caller_thread, seen_threads) - # The thread we observed is the singleton dispatch thread. dispatch_thread = _get_async_dispatch_thread() assert dispatch_thread is not None self.assertEqual(seen_threads.pop(), dispatch_thread.ident) - def test_dispatch_loop_is_single_event_loop(self) -> None: - """All dispatches run on the SAME event loop instance.""" - captured: List[asyncio.AbstractEventLoop] = [] + # Every coroutine observed the same loop, and it is the singleton. + self.assertEqual(len(seen_loops), 8) + for loop in seen_loops: + self.assertIs(loop, seen_loops[0]) + self.assertIs(seen_loops[0], _get_async_dispatch_loop()) - async def capture() -> int: - captured.append(asyncio.get_running_loop()) - return 1 + def test_new_requests_run_during_one_in_flight_request(self) -> None: + """Requests fired while a long one is in-flight all start AND finish + before it does, proving they are not serialized behind it. - async def run_many() -> None: - await asyncio.gather(*[ - _dispatch_to_async_loop(capture()) for _ in range(5) - ]) - - asyncio.run(run_many()) - - self.assertEqual(len(captured), 5) - first = captured[0] - for loop in captured: - self.assertIs(loop, first) - self.assertIs(first, _get_async_dispatch_loop()) - - def test_concurrent_dispatches_do_not_serialize(self) -> None: - """Slow dispatches run in parallel on the loop; new requests do not - wait for earlier ones to finish.""" - n = 6 - per_call_sleep = 0.3 - - async def slow() -> str: - await asyncio.sleep(per_call_sleep) - return 'done' - - async def run_many() -> List[str]: - return await asyncio.gather(*[ - _dispatch_to_async_loop(slow()) for _ in range(n) - ]) - - start = time.monotonic() - results = asyncio.run(run_many()) - elapsed = time.monotonic() - start - - self.assertEqual(results, ['done'] * n) - # Serialized would be ~ n * per_call_sleep. Parallel ~ per_call_sleep. - # Allow generous margin for CI noise. - self.assertLess(elapsed, per_call_sleep * 2) - - def test_new_request_does_not_wait_for_in_flight_request(self) -> None: - """A new async request is submitted to the dispatch thread - immediately and runs while an earlier request is still in-flight. - - This is the explicit guarantee that async UDF dispatch is not - serialized: a request fired AFTER another long one has started - must (a) start before the long one finishes, (b) finish before - the long one finishes, and (c) be submitted with negligible - latency from the caller's perspective. + Assertions compare event ordering (relative timestamps) rather than + absolute wall-clock durations, so they are robust to CI load. """ long_sleep = 1.0 - ts: Dict[str, float] = {} - # Created lazily on the dispatch loop so the asyncio.Event is bound - # to the correct loop. - signals: Dict[str, asyncio.Event] = {} - - async def long_running() -> str: - ts['long_started'] = time.monotonic() - signals['started'] = asyncio.Event() - signals['started'].set() - await asyncio.sleep(long_sleep) - ts['long_finished'] = time.monotonic() - return 'long' - - async def quick() -> str: - ts['quick_started'] = time.monotonic() - await asyncio.sleep(0) - ts['quick_finished'] = time.monotonic() - return 'quick' - - async def driver() -> None: - long_task = asyncio.create_task( - _dispatch_to_async_loop(long_running()), - ) - # Wait until the long task has actually started on the - # dispatch loop. Only after this point can we be sure the - # next dispatch is "during" an in-flight request. - for _ in range(100): - await asyncio.sleep(0.01) - if 'started' in signals and signals['started'].is_set(): - break - self.assertIn('long_started', ts) - - ts['quick_dispatch_called'] = time.monotonic() - quick_result = await _dispatch_to_async_loop(quick()) - ts['quick_dispatch_returned'] = time.monotonic() - self.assertEqual(quick_result, 'quick') - - long_result = await long_task - self.assertEqual(long_result, 'long') - - asyncio.run(driver()) - - # The new request actually overlapped the in-flight one. - self.assertGreater(ts['quick_started'], ts['long_started']) - self.assertLess(ts['quick_started'], ts['long_finished']) - self.assertLess(ts['quick_finished'], ts['long_finished']) - - # Submission of the new request to the dispatch thread is - # non-blocking: the awaiter returned in well under the long - # request's remaining time. - dispatch_latency = ts['quick_dispatch_returned'] \ - - ts['quick_dispatch_called'] - self.assertLess(dispatch_latency, long_sleep / 2) - - def test_many_new_requests_run_during_one_in_flight_request(self) -> None: - """Many new async requests, each fired sequentially while a single - long-running request is in-flight, all start AND finish before the - long one finishes.""" - long_sleep = 1.0 n_quick = 8 ts: Dict[str, float] = {} quick_finished: List[float] = [] + # Created lazily on the dispatch loop so the asyncio.Event is bound + # to the correct loop. signals: Dict[str, asyncio.Event] = {} async def long_running() -> str: @@ -328,26 +165,29 @@ async def driver() -> None: long_task = asyncio.create_task( _dispatch_to_async_loop(long_running()), ) - # Wait for the long task to start. + # Wait until the long task is actually running on the dispatch + # loop before firing the others. for _ in range(100): await asyncio.sleep(0.01) if 'started' in signals and signals['started'].is_set(): break - results = await asyncio.gather(*[ - _dispatch_to_async_loop(quick(i)) for i in range(n_quick) - ]) + results = await asyncio.gather( + *[ + _dispatch_to_async_loop(quick(i)) for i in range(n_quick) + ], + ) self.assertEqual(results, list(range(n_quick))) await long_task asyncio.run(driver()) - # All quick requests finished before the long one did, proving - # they were not queued behind it. + # All quick requests finished between the long request's start and + # finish, proving they were not queued behind it. self.assertEqual(len(quick_finished), n_quick) for finish in quick_finished: - self.assertLess(finish, ts['long_finished']) self.assertGreater(finish, ts['long_started']) + self.assertLess(finish, ts['long_finished']) def test_loop_bound_resource_reused_across_dispatches(self) -> None: """A resource keyed by id(loop) is shared by every async request, @@ -387,36 +227,9 @@ async def driver() -> None: asyncio.run(driver()) self.assertEqual(str(ctx.exception), 'boom') - def test_dispatch_with_cancel_event(self) -> None: - """`_cancellable_run` on the dispatch loop honors the cancel event.""" - cancel_event = threading.Event() - - async def blocked() -> str: - await asyncio.sleep(999) - return 'unreachable' - - def trip_cancel() -> None: - time.sleep(0.1) - cancel_event.set() - - timer = threading.Thread(target=trip_cancel) - timer.start() - - async def driver() -> None: - await _dispatch_to_async_loop( - _cancellable_run(cancel_event, blocked()), - ) - - start = time.monotonic() - with self.assertRaises(asyncio.CancelledError): - asyncio.run(driver()) - elapsed = time.monotonic() - start - timer.join() - # 0.1s delay + 0.1s poll interval + margin - self.assertLess(elapsed, 0.5) - def test_dispatch_loop_survives_after_cancellation(self) -> None: - """The dispatch loop remains usable after a cancelled request.""" + """A cancelled dispatch (via cancel_event) cancels the work on the + loop, and the loop stays usable for later requests.""" cancel_event = threading.Event() cancel_event.set() @@ -448,53 +261,23 @@ async def driver_ok() -> str: # Records the thread that actually executes each UDF body, keyed by tag. _dispatch_observation: Dict[str, int] = {} _dispatch_observation_lock = threading.Lock() -# Per-tag start / finish timestamps, used by the "no waiting for in-flight" -# test below to assert overlap between concurrent requests. -_dispatch_started_at: Dict[str, float] = {} -_dispatch_finished_at: Dict[str, float] = {} def _record(tag: str) -> None: with _dispatch_observation_lock: _dispatch_observation[tag] = threading.get_ident() - _dispatch_started_at[tag] = time.monotonic() - - -def _record_finish(tag: str) -> None: - with _dispatch_observation_lock: - _dispatch_finished_at[tag] = time.monotonic() @udf async def _async_record_udf(tag: str) -> int: _record(tag) await asyncio.sleep(0) - _record_finish(tag) - return len(tag) - - -@udf -async def _async_slow_udf(tag: str) -> int: - _record(tag) - await asyncio.sleep(0.4) - _record_finish(tag) - return len(tag) - - -@udf -async def _async_long_udf(tag: str) -> int: - """Long-running async UDF used to verify that newly arriving async - requests do not have to wait for it to finish.""" - _record(tag) - await asyncio.sleep(1.0) - _record_finish(tag) return len(tag) @udf def _sync_record_udf(tag: str) -> int: _record(tag) - _record_finish(tag) return len(tag) @@ -540,47 +323,28 @@ async def send(msg: Dict[str, Any]) -> None: def _reset_dispatch_observation() -> None: with _dispatch_observation_lock: _dispatch_observation.clear() - _dispatch_started_at.clear() - _dispatch_finished_at.clear() class TestApplicationDispatchRouting(unittest.TestCase): - """End-to-end: Application routes async UDFs to the dispatch loop and - sync UDFs to a worker thread, and concurrent async requests run in - parallel on the dispatch loop.""" + """End-to-end: ``Application`` routes async UDFs to the dispatch loop and + sync UDFs to a worker thread.""" def setUp(self) -> None: _reset_dispatch_observation() self.app = Application( functions=[ _async_record_udf, - _async_slow_udf, - _async_long_udf, _sync_record_udf, ], disable_metrics=True, ) - @staticmethod - def _headers_dict(scope: Dict[str, Any]) -> Dict[bytes, bytes]: - return {k: v for k, v in scope['headers']} - def _invoke(self, name: str, rows: List[Tuple[Any, ...]]) -> List[Dict[str, Any]]: scope, receive, send, sent = _make_invoke_args(name, rows) scope['headers'] = list(scope['headers']) - # Application reads headers as a dict via ``dict(scope['headers'])``, - # which works for our list of tuples. asyncio.run(self.app(scope, receive, send)) return sent - async def _invoke_async( - self, name: str, rows: List[Tuple[Any, ...]], - ) -> List[Dict[str, Any]]: - scope, receive, send, sent = _make_invoke_args(name, rows) - scope['headers'] = list(scope['headers']) - await self.app(scope, receive, send) - return sent - def test_async_udf_runs_on_dispatch_thread(self) -> None: """An async UDF body executes on the dedicated dispatch thread.""" sent = self._invoke('_async_record_udf', [('alpha',)]) @@ -593,7 +357,8 @@ def test_async_udf_runs_on_dispatch_thread(self) -> None: self.assertEqual(_dispatch_observation['alpha'], dispatch_thread.ident) def test_sync_udf_runs_on_a_worker_thread_not_dispatch(self) -> None: - """A sync UDF body runs on a worker thread, NOT the dispatch thread.""" + """A sync UDF body runs on a worker thread, NOT the dispatch thread + and NOT the caller thread.""" # Force the dispatch thread to exist so we can compare ids. _get_async_dispatch_loop() dispatch_thread = _get_async_dispatch_thread() @@ -609,88 +374,6 @@ def test_sync_udf_runs_on_a_worker_thread_not_dispatch(self) -> None: self.assertNotEqual(sync_thread, threading.get_ident()) self.assertNotEqual(sync_thread, dispatch_thread.ident) - def test_concurrent_async_requests_share_dispatch_thread(self) -> None: - """Two concurrent async UDF requests both execute on the dispatch thread.""" - - async def driver() -> None: - await asyncio.gather( - self._invoke_async('_async_record_udf', [('one',)]), - self._invoke_async('_async_record_udf', [('two',)]), - self._invoke_async('_async_record_udf', [('three',)]), - ) - - asyncio.run(driver()) - - dispatch_thread = _get_async_dispatch_thread() - assert dispatch_thread is not None - with _dispatch_observation_lock: - for tag in ('one', 'two', 'three'): - self.assertEqual( - _dispatch_observation[tag], dispatch_thread.ident, - f'tag {tag} ran on wrong thread', - ) - - def test_concurrent_async_requests_do_not_serialize(self) -> None: - """Concurrent async UDF requests run in parallel on the dispatch loop; - a new request does not wait for in-flight ones.""" - n = 4 - per_call_sleep = 0.4 - - async def driver() -> None: - await asyncio.gather(*[ - self._invoke_async('_async_slow_udf', [(f'r{i}',)]) - for i in range(n) - ]) - - start = time.monotonic() - asyncio.run(driver()) - elapsed = time.monotonic() - start - - # Serialized would be ~ n * per_call_sleep. Parallel ~ per_call_sleep. - self.assertLess(elapsed, per_call_sleep * 2) - - def test_new_async_request_runs_during_in_flight_request(self) -> None: - """An async request arriving while another is still running gets - dispatched onto the async thread immediately and finishes before - the in-flight one — i.e., a new request does not wait for any - existing async request to be served.""" - - async def driver() -> None: - long_call = asyncio.create_task( - self._invoke_async('_async_long_udf', [('long',)]), - ) - # Spin until the long request has actually started executing - # on the dispatch thread, so any new dispatch we fire after - # this point is genuinely "during" an in-flight request. - for _ in range(200): - await asyncio.sleep(0.01) - with _dispatch_observation_lock: - if 'long' in _dispatch_started_at: - break - self.assertIn('long', _dispatch_started_at) - - t_call = time.monotonic() - await self._invoke_async('_async_record_udf', [('quick',)]) - t_returned = time.monotonic() - await long_call - - asyncio.run(driver()) - - with _dispatch_observation_lock: - long_started = _dispatch_started_at['long'] - long_finished = _dispatch_finished_at['long'] - quick_started = _dispatch_started_at['quick'] - quick_finished = _dispatch_finished_at['quick'] - - # quick must have started AFTER long started (it was fired later) - # but BEFORE long finished, and itself finished before long did. - self.assertGreater(quick_started, long_started) - self.assertLess(quick_started, long_finished) - self.assertLess(quick_finished, long_finished) - - # Sanity: the long UDF body really did span the long sleep. - self.assertGreaterEqual(long_finished - long_started, 0.9) - if __name__ == '__main__': unittest.main() From 05077b2b4dac5663e4cc97a0075c1393bf0a2953 Mon Sep 17 00:00:00 2001 From: Ankit Saini Date: Tue, 16 Jun 2026 07:53:34 +0530 Subject: [PATCH 19/19] improve comments --- singlestoredb/apps/_python_udfs.py | 6 +++--- singlestoredb/functions/ext/asgi.py | 5 ----- singlestoredb/tests/test_udf_event_loop.py | 14 +++++++------- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/singlestoredb/apps/_python_udfs.py b/singlestoredb/apps/_python_udfs.py index 267a12c63..79d6f8a0f 100644 --- a/singlestoredb/apps/_python_udfs.py +++ b/singlestoredb/apps/_python_udfs.py @@ -61,9 +61,9 @@ async def run_udf_app( f'You can only define a maximum of {MAX_UDFS_LIMIT} functions.', ) - # Raise the keep-alive timeout so uvicorn does not actively close idle - # connections aggressively. Whichever side closes first holds the socket in TIME_WAIT - # (~60s on Linux), so server-initiated closes churn sockets under load. + # Raise the keep-alive timeout so uvicorn does not close idle connections so + # eagerly. Whichever side closes first holds the socket in TIME_WAIT (~60s on + # Linux), so server-initiated closes churn sockets under load. keep_alive_timeout = int( os.environ.get('SINGLESTOREDB_UDF_KEEPALIVE_TIMEOUT', '120'), ) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 083ea5d03..f3a076ad5 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -228,11 +228,6 @@ def run_loop() -> None: return _async_dispatch_loop -def _get_async_dispatch_thread() -> 'Optional[threading.Thread]': - """Return the dedicated dispatch thread (or ``None`` if not started).""" - return _async_dispatch_thread - - async def _dispatch_to_async_loop(coro: Any) -> Any: """ Schedule ``coro`` on the dedicated async-dispatch loop and await it. diff --git a/singlestoredb/tests/test_udf_event_loop.py b/singlestoredb/tests/test_udf_event_loop.py index e5b772684..a85b124e9 100644 --- a/singlestoredb/tests/test_udf_event_loop.py +++ b/singlestoredb/tests/test_udf_event_loop.py @@ -12,10 +12,10 @@ from typing import Tuple from ..functions import udf +from ..functions.ext import asgi from ..functions.ext.asgi import _cancellable_run from ..functions.ext.asgi import _dispatch_to_async_loop from ..functions.ext.asgi import _get_async_dispatch_loop -from ..functions.ext.asgi import _get_async_dispatch_thread from ..functions.ext.asgi import Application from ..functions.ext.asgi import to_thread @@ -24,7 +24,7 @@ class TestCancellableRun(unittest.TestCase): """Unit tests for ``_cancellable_run`` and the ``to_thread`` helper.""" def test_cancel_event_cancels_blocked_coroutine(self) -> None: - """Tripping ``cancel_event`` interrupts a coroutine blocked on I/O. + """Tripping ``cancel_event`` interrupts a coroutine stuck in a long await. The coroutine sleeps far longer than the test could tolerate, so the test only completes if the cancel signal actually unblocks it. @@ -123,7 +123,7 @@ async def run_many() -> None: # singleton dispatch thread. self.assertEqual(len(seen_threads), 1) self.assertNotIn(caller_thread, seen_threads) - dispatch_thread = _get_async_dispatch_thread() + dispatch_thread = asgi._async_dispatch_thread assert dispatch_thread is not None self.assertEqual(seen_threads.pop(), dispatch_thread.ident) @@ -134,8 +134,8 @@ async def run_many() -> None: self.assertIs(seen_loops[0], _get_async_dispatch_loop()) def test_new_requests_run_during_one_in_flight_request(self) -> None: - """Requests fired while a long one is in-flight all start AND finish - before it does, proving they are not serialized behind it. + """Requests fired while a long one is in-flight all finish before it + does, proving they are not serialized behind it. Assertions compare event ordering (relative timestamps) rather than absolute wall-clock durations, so they are robust to CI load. @@ -351,7 +351,7 @@ def test_async_udf_runs_on_dispatch_thread(self) -> None: statuses = [m for m in sent if m.get('type') == 'http.response.start'] self.assertTrue(statuses and statuses[0]['status'] == 200, sent) - dispatch_thread = _get_async_dispatch_thread() + dispatch_thread = asgi._async_dispatch_thread assert dispatch_thread is not None with _dispatch_observation_lock: self.assertEqual(_dispatch_observation['alpha'], dispatch_thread.ident) @@ -361,7 +361,7 @@ def test_sync_udf_runs_on_a_worker_thread_not_dispatch(self) -> None: and NOT the caller thread.""" # Force the dispatch thread to exist so we can compare ids. _get_async_dispatch_loop() - dispatch_thread = _get_async_dispatch_thread() + dispatch_thread = asgi._async_dispatch_thread assert dispatch_thread is not None sent = self._invoke('_sync_record_udf', [('beta',)])