From 093652866a3f2e47bbb117d7aa7e0727f6d0cec8 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Fri, 15 May 2026 09:31:54 -0400 Subject: [PATCH 01/11] initial lock refactor and race fixes --- pymongo/asynchronous/pool.py | 187 ++++++++++++++++++++--------------- pymongo/synchronous/pool.py | 187 ++++++++++++++++++++--------------- 2 files changed, 214 insertions(+), 160 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a5d5b28990..7888422d28 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -709,6 +709,11 @@ def __init__( :param options: a PoolOptions instance :param is_sdam: whether to call hello for each new AsyncConnection """ + # Main lock only used to protect updating attributes. + # Avoid any additional work while holding the lock. + # If looping over an attribute, copy the container and do not take the lock. + self.lock = _async_create_lock() + if options.pause_enabled: self.state = PoolState.PAUSED else: @@ -720,10 +725,9 @@ def __init__( # and returned to pool from the left side. Stale sockets removed # from the right side. self.conns: collections.deque[AsyncConnection] = collections.deque() - self.active_contexts: set[_CancellationContext] = set() - self.lock = _async_create_lock() - self._max_connecting_cond = _async_create_condition(self.lock) - self.active_sockets = 0 + # This lock should only be contended by threads adding/removing connections. + self._conns_lock = _async_create_lock() + # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 # Track whether the sockets in this pool are writeable or not. @@ -748,16 +752,19 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _async_create_condition(self.lock) self.requests = 0 + # This lock should only be contended by threads adding/removing self.requests. + self.size_cond = _async_create_condition(_async_create_lock()) self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: self.max_pool_size = float("inf") + # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _async_create_condition(self.lock) self._pending = 0 + # This lock should only be contended by threads adding/removing self._pending. + self._max_connecting_cond = _async_create_condition(_async_create_lock()) self._max_connecting = self.opts.max_connecting self._client_id = client_id # Log before publishing event to prevent potential listener preemption in tests @@ -777,29 +784,41 @@ def __init__( ) # Similar to active_sockets but includes threads in the wait queue. self.operation_count: int = 0 + # This lock should be contended on every operation. + self._operation_count_lock = _async_create_lock() + + self.active_contexts: set[_CancellationContext] = set() + self.active_sockets = 0 # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). self.__pinned_sockets: set[AsyncConnection] = set() self.ncursors = 0 self.ntxns = 0 + # This lock protects self.active_contexts, self.active_sockets, + # self.__pinned_sockets, self.ncursors, and self.ntxns. + self._active_contexts_lock = _async_create_lock() async def ready(self) -> None: # Take the lock to avoid the race condition described in PYTHON-2699. - async with self.lock: - if self.state != PoolState.READY: + state_changed = False + if self.state != PoolState.READY: + async with self.lock: self.state = PoolState.READY - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.POOL_READY, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - ) - if self.enabled_for_cmap: - assert self.opts._event_listeners is not None - self.opts._event_listeners.publish_pool_ready(self.address) + state_changed = True + if not state_changed: + return + if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.POOL_READY, + clientId=self._client_id, + serverHost=self.address[0], + serverPort=self.address[1], + ) + if self.enabled_for_cmap: + assert self.opts._event_listeners is not None + self.opts._event_listeners.publish_pool_ready(self.address) @property def closed(self) -> bool: @@ -813,38 +832,46 @@ async def _reset( interrupt_connections: bool = False, ) -> None: old_state = self.state - async with self.size_cond: - if self.closed: - return + if self.closed: + return + is_fork = False + async with self.lock: if self.opts.pause_enabled and pause and not self.opts.load_balanced: old_state, self.state = self.state, PoolState.PAUSED self.gen.inc(service_id) newpid = os.getpid() if self.pid != newpid: self.pid = newpid + is_fork = True + if is_fork: + async with self._active_contexts_lock: self.active_sockets = 0 + async with self._operation_count_lock: self.operation_count = 0 + async with self._conns_lock: if service_id is None: sockets, self.conns = self.conns, collections.deque() - else: - discard: collections.deque = collections.deque() # type: ignore[type-arg] - keep: collections.deque = collections.deque() # type: ignore[type-arg] - for conn in self.conns: - if conn.service_id == service_id: - discard.append(conn) - else: - keep.append(conn) - sockets = discard + if service_id is not None: + discard: collections.deque = collections.deque() # type: ignore[type-arg] + keep: collections.deque = collections.deque() # type: ignore[type-arg] + for conn in self.conns: + if conn.service_id == service_id: + discard.append(conn) + else: + keep.append(conn) + sockets = discard + async with self._conns_lock: self.conns = keep if close: - self.state = PoolState.CLOSED + async with self.lock: + self.state = PoolState.CLOSED # Clear the wait queue self._max_connecting_cond.notify_all() self.size_cond.notify_all() if interrupt_connections: - for context in self.active_contexts: + for context in self.active_contexts.copy(): context.cancel() listeners = self.opts._event_listeners @@ -903,9 +930,8 @@ async def update_is_writable(self, is_writable: Optional[bool]) -> None: Pool. """ self.is_writable = is_writable - async with self.lock: - for _socket in self.conns: - _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] + for _socket in self.conns: + _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] async def reset( self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False @@ -936,12 +962,9 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: if self.opts.max_idle_time_seconds is not None: close_conns = [] - async with self.lock: - while ( - self.conns - and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds - ): - close_conns.append(self.conns.pop()) + conns = self.conns.copy() + while conns and conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds: + close_conns.append(conns.pop()) if not _IS_SYNC: await asyncio.gather( *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value] @@ -952,12 +975,12 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: await conn.close_conn(ConnectionClosedReason.IDLE) while True: + # There are enough sockets in the pool. + if len(self.conns) + self.active_sockets >= self.opts.min_pool_size: + return + if self.requests >= self.opts.min_pool_size: + return async with self.size_cond: - # There are enough sockets in the pool. - if len(self.conns) + self.active_sockets >= self.opts.min_pool_size: - return - if self.requests >= self.opts.min_pool_size: - return self.requests += 1 incremented = False try: @@ -970,13 +993,14 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: incremented = True conn = await self.connect() close_conn = False - async with self.lock: - # Close connection and return if the pool was reset during - # socket creation or while acquiring the pool lock. - if self.gen.get_overall() != reference_generation: - close_conn = True - if not close_conn: + # Close connection and return if the pool was reset during + # socket creation or while acquiring the pool lock. + if self.gen.get_overall() != reference_generation: + close_conn = True + if not close_conn: + async with self._conns_lock: self.conns.appendleft(conn) + async with self._active_contexts_lock: self.active_contexts.discard(conn.cancel_context) if close_conn: await conn.close_conn(ConnectionClosedReason.STALE) @@ -1015,11 +1039,11 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A Note that the pool does not keep a reference to the socket -- you must call checkin() when you're done with it. """ - async with self.lock: + # Use a temporary context so that interrupt_connections can cancel creating the socket. + tmp_context = _CancellationContext() + async with self._active_contexts_lock: conn_id = self.next_connection_id self.next_connection_id += 1 - # Use a temporary context so that interrupt_connections can cancel creating the socket. - tmp_context = _CancellationContext() self.active_contexts.add(tmp_context) listeners = self.opts._event_listeners @@ -1040,7 +1064,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A networking_interface = await _configured_protocol_interface(self.address, self.opts) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: - async with self.lock: + async with self._active_contexts_lock: self.active_contexts.discard(tmp_context) if self.enabled_for_cmap: assert listeners is not None @@ -1065,7 +1089,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A raise conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type] - async with self.lock: + async with self._active_contexts_lock: self.active_contexts.add(conn.cancel_context) self.active_contexts.discard(tmp_context) if tmp_context.cancelled: @@ -1082,7 +1106,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A await conn.authenticate() # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as e: - async with self.lock: + async with self._active_contexts_lock: self.active_contexts.discard(conn.cancel_context) if not completed_hello: self._handle_connection_error(e) @@ -1144,7 +1168,7 @@ async def checkout( durationMS=duration, ) try: - async with self.lock: + async with self._active_contexts_lock: self.active_contexts.add(conn.cancel_context) yield conn # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. @@ -1163,11 +1187,11 @@ async def checkout( await self.checkin(conn) raise if conn.pinned_txn: - async with self.lock: + async with self._active_contexts_lock: self.__pinned_sockets.add(conn) self.ntxns += 1 elif conn.pinned_cursor: - async with self.lock: + async with self._active_contexts_lock: self.__pinned_sockets.add(conn) self.ncursors += 1 elif conn.active: @@ -1231,7 +1255,7 @@ async def _get_conn( "Attempted to check out a connection from closed connection pool" ) - async with self.lock: + async with self._operation_count_lock: self.operation_count += 1 # Get a free socket or create one. @@ -1260,9 +1284,9 @@ async def _get_conn( incremented = False emitted_event = False try: - async with self.lock: + async with self._active_contexts_lock: self.active_sockets += 1 - incremented = True + incremented = True while conn is None: # CMAP: we MUST wait for either maxConnecting OR for a socket # to be checked back into the pool. @@ -1280,7 +1304,8 @@ async def _get_conn( self._raise_if_not_ready(checkout_started_time, emit_event=False) try: - conn = self.conns.popleft() + async with self._conns_lock: + conn = self.conns.popleft() except IndexError: self._pending += 1 if conn: # We got a socket from the pool @@ -1301,9 +1326,10 @@ async def _get_conn( await conn.close_conn(ConnectionClosedReason.ERROR) async with self.size_cond: self.requests -= 1 - if incremented: - self.active_sockets -= 1 self.size_cond.notify() + if incremented: + async with self._active_contexts_lock: + self.active_sockets -= 1 if not emitted_event: duration = time.monotonic() - checkout_started_time @@ -1338,9 +1364,9 @@ async def checkin(self, conn: AsyncConnection) -> None: conn.active = False conn.pinned_txn = False conn.pinned_cursor = False - self.__pinned_sockets.discard(conn) listeners = self.opts._event_listeners - async with self.lock: + async with self._active_contexts_lock: + self.__pinned_sockets.discard(conn) self.active_contexts.discard(conn.cancel_context) if self.enabled_for_cmap: assert listeners is not None @@ -1379,28 +1405,29 @@ async def checkin(self, conn: AsyncConnection) -> None: ) else: close_conn = False - async with self.lock: - # Hold the lock to ensure this section does not race with - # Pool.reset(). - if self.stale_generation(conn.generation, conn.service_id): - close_conn = True - else: - conn.update_last_checkin_time() - conn.update_is_writable(bool(self.is_writable)) + if self.stale_generation(conn.generation, conn.service_id): + close_conn = True + else: + conn.update_last_checkin_time() + conn.update_is_writable(bool(self.is_writable)) + async with self._conns_lock: self.conns.appendleft(conn) + async with self._max_connecting_cond: # Notify any threads waiting to create a connection. self._max_connecting_cond.notify() if close_conn: await conn.close_conn(ConnectionClosedReason.STALE) - async with self.size_cond: + async with self._active_contexts_lock: if txn: self.ntxns -= 1 elif cursor: self.ncursors -= 1 - self.requests -= 1 self.active_sockets -= 1 + async with self._operation_count_lock: self.operation_count -= 1 + async with self.size_cond: + self.requests -= 1 self.size_cond.notify() async def _perished(self, conn: AsyncConnection) -> bool: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 25f2d08fe7..5985fdb318 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -707,6 +707,11 @@ def __init__( :param options: a PoolOptions instance :param is_sdam: whether to call hello for each new Connection """ + # Main lock only used to protect updating attributes. + # Avoid any additional work while holding the lock. + # If looping over an attribute, copy the container and do not take the lock. + self.lock = _create_lock() + if options.pause_enabled: self.state = PoolState.PAUSED else: @@ -718,10 +723,9 @@ def __init__( # and returned to pool from the left side. Stale sockets removed # from the right side. self.conns: collections.deque[Connection] = collections.deque() - self.active_contexts: set[_CancellationContext] = set() - self.lock = _create_lock() - self._max_connecting_cond = _create_condition(self.lock) - self.active_sockets = 0 + # This lock should only be contended by threads adding/removing connections. + self._conns_lock = _create_lock() + # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 # Track whether the sockets in this pool are writeable or not. @@ -746,16 +750,19 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _create_condition(self.lock) self.requests = 0 + # This lock should only be contended by threads adding/removing self.requests. + self.size_cond = _create_condition(_create_lock()) self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: self.max_pool_size = float("inf") + # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _create_condition(self.lock) self._pending = 0 + # This lock should only be contended by threads adding/removing self._pending. + self._max_connecting_cond = _create_condition(_create_lock()) self._max_connecting = self.opts.max_connecting self._client_id = client_id # Log before publishing event to prevent potential listener preemption in tests @@ -775,29 +782,41 @@ def __init__( ) # Similar to active_sockets but includes threads in the wait queue. self.operation_count: int = 0 + # This lock should be contended on every operation. + self._operation_count_lock = _create_lock() + + self.active_contexts: set[_CancellationContext] = set() + self.active_sockets = 0 # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). self.__pinned_sockets: set[Connection] = set() self.ncursors = 0 self.ntxns = 0 + # This lock protects self.active_contexts, self.active_sockets, + # self.__pinned_sockets, self.ncursors, and self.ntxns. + self._active_contexts_lock = _create_lock() def ready(self) -> None: # Take the lock to avoid the race condition described in PYTHON-2699. - with self.lock: - if self.state != PoolState.READY: + state_changed = False + if self.state != PoolState.READY: + with self.lock: self.state = PoolState.READY - if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _CONNECTION_LOGGER, - message=_ConnectionStatusMessage.POOL_READY, - clientId=self._client_id, - serverHost=self.address[0], - serverPort=self.address[1], - ) - if self.enabled_for_cmap: - assert self.opts._event_listeners is not None - self.opts._event_listeners.publish_pool_ready(self.address) + state_changed = True + if not state_changed: + return + if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + message=_ConnectionStatusMessage.POOL_READY, + clientId=self._client_id, + serverHost=self.address[0], + serverPort=self.address[1], + ) + if self.enabled_for_cmap: + assert self.opts._event_listeners is not None + self.opts._event_listeners.publish_pool_ready(self.address) @property def closed(self) -> bool: @@ -811,38 +830,46 @@ def _reset( interrupt_connections: bool = False, ) -> None: old_state = self.state - with self.size_cond: - if self.closed: - return + if self.closed: + return + is_fork = False + with self.lock: if self.opts.pause_enabled and pause and not self.opts.load_balanced: old_state, self.state = self.state, PoolState.PAUSED self.gen.inc(service_id) newpid = os.getpid() if self.pid != newpid: self.pid = newpid + is_fork = True + if is_fork: + with self._active_contexts_lock: self.active_sockets = 0 + with self._operation_count_lock: self.operation_count = 0 + with self._conns_lock: if service_id is None: sockets, self.conns = self.conns, collections.deque() - else: - discard: collections.deque = collections.deque() # type: ignore[type-arg] - keep: collections.deque = collections.deque() # type: ignore[type-arg] - for conn in self.conns: - if conn.service_id == service_id: - discard.append(conn) - else: - keep.append(conn) - sockets = discard + if service_id is not None: + discard: collections.deque = collections.deque() # type: ignore[type-arg] + keep: collections.deque = collections.deque() # type: ignore[type-arg] + for conn in self.conns: + if conn.service_id == service_id: + discard.append(conn) + else: + keep.append(conn) + sockets = discard + with self._conns_lock: self.conns = keep if close: - self.state = PoolState.CLOSED + with self.lock: + self.state = PoolState.CLOSED # Clear the wait queue self._max_connecting_cond.notify_all() self.size_cond.notify_all() if interrupt_connections: - for context in self.active_contexts: + for context in self.active_contexts.copy(): context.cancel() listeners = self.opts._event_listeners @@ -901,9 +928,8 @@ def update_is_writable(self, is_writable: Optional[bool]) -> None: Pool. """ self.is_writable = is_writable - with self.lock: - for _socket in self.conns: - _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] + for _socket in self.conns: + _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] def reset( self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False @@ -932,12 +958,9 @@ def remove_stale_sockets(self, reference_generation: int) -> None: if self.opts.max_idle_time_seconds is not None: close_conns = [] - with self.lock: - while ( - self.conns - and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds - ): - close_conns.append(self.conns.pop()) + conns = self.conns.copy() + while conns and conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds: + close_conns.append(conns.pop()) if not _IS_SYNC: asyncio.gather( *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value] @@ -948,12 +971,12 @@ def remove_stale_sockets(self, reference_generation: int) -> None: conn.close_conn(ConnectionClosedReason.IDLE) while True: + # There are enough sockets in the pool. + if len(self.conns) + self.active_sockets >= self.opts.min_pool_size: + return + if self.requests >= self.opts.min_pool_size: + return with self.size_cond: - # There are enough sockets in the pool. - if len(self.conns) + self.active_sockets >= self.opts.min_pool_size: - return - if self.requests >= self.opts.min_pool_size: - return self.requests += 1 incremented = False try: @@ -966,13 +989,14 @@ def remove_stale_sockets(self, reference_generation: int) -> None: incremented = True conn = self.connect() close_conn = False - with self.lock: - # Close connection and return if the pool was reset during - # socket creation or while acquiring the pool lock. - if self.gen.get_overall() != reference_generation: - close_conn = True - if not close_conn: + # Close connection and return if the pool was reset during + # socket creation or while acquiring the pool lock. + if self.gen.get_overall() != reference_generation: + close_conn = True + if not close_conn: + with self._conns_lock: self.conns.appendleft(conn) + with self._active_contexts_lock: self.active_contexts.discard(conn.cancel_context) if close_conn: conn.close_conn(ConnectionClosedReason.STALE) @@ -1011,11 +1035,11 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect Note that the pool does not keep a reference to the socket -- you must call checkin() when you're done with it. """ - with self.lock: + # Use a temporary context so that interrupt_connections can cancel creating the socket. + tmp_context = _CancellationContext() + with self._active_contexts_lock: conn_id = self.next_connection_id self.next_connection_id += 1 - # Use a temporary context so that interrupt_connections can cancel creating the socket. - tmp_context = _CancellationContext() self.active_contexts.add(tmp_context) listeners = self.opts._event_listeners @@ -1036,7 +1060,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect networking_interface = _configured_socket_interface(self.address, self.opts) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: - with self.lock: + with self._active_contexts_lock: self.active_contexts.discard(tmp_context) if self.enabled_for_cmap: assert listeners is not None @@ -1061,7 +1085,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect raise conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type] - with self.lock: + with self._active_contexts_lock: self.active_contexts.add(conn.cancel_context) self.active_contexts.discard(tmp_context) if tmp_context.cancelled: @@ -1078,7 +1102,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect conn.authenticate() # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as e: - with self.lock: + with self._active_contexts_lock: self.active_contexts.discard(conn.cancel_context) if not completed_hello: self._handle_connection_error(e) @@ -1140,7 +1164,7 @@ def checkout( durationMS=duration, ) try: - with self.lock: + with self._active_contexts_lock: self.active_contexts.add(conn.cancel_context) yield conn # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. @@ -1159,11 +1183,11 @@ def checkout( self.checkin(conn) raise if conn.pinned_txn: - with self.lock: + with self._active_contexts_lock: self.__pinned_sockets.add(conn) self.ntxns += 1 elif conn.pinned_cursor: - with self.lock: + with self._active_contexts_lock: self.__pinned_sockets.add(conn) self.ncursors += 1 elif conn.active: @@ -1227,7 +1251,7 @@ def _get_conn( "Attempted to check out a connection from closed connection pool" ) - with self.lock: + with self._operation_count_lock: self.operation_count += 1 # Get a free socket or create one. @@ -1256,9 +1280,9 @@ def _get_conn( incremented = False emitted_event = False try: - with self.lock: + with self._active_contexts_lock: self.active_sockets += 1 - incremented = True + incremented = True while conn is None: # CMAP: we MUST wait for either maxConnecting OR for a socket # to be checked back into the pool. @@ -1276,7 +1300,8 @@ def _get_conn( self._raise_if_not_ready(checkout_started_time, emit_event=False) try: - conn = self.conns.popleft() + with self._conns_lock: + conn = self.conns.popleft() except IndexError: self._pending += 1 if conn: # We got a socket from the pool @@ -1297,9 +1322,10 @@ def _get_conn( conn.close_conn(ConnectionClosedReason.ERROR) with self.size_cond: self.requests -= 1 - if incremented: - self.active_sockets -= 1 self.size_cond.notify() + if incremented: + with self._active_contexts_lock: + self.active_sockets -= 1 if not emitted_event: duration = time.monotonic() - checkout_started_time @@ -1334,9 +1360,9 @@ def checkin(self, conn: Connection) -> None: conn.active = False conn.pinned_txn = False conn.pinned_cursor = False - self.__pinned_sockets.discard(conn) listeners = self.opts._event_listeners - with self.lock: + with self._active_contexts_lock: + self.__pinned_sockets.discard(conn) self.active_contexts.discard(conn.cancel_context) if self.enabled_for_cmap: assert listeners is not None @@ -1375,28 +1401,29 @@ def checkin(self, conn: Connection) -> None: ) else: close_conn = False - with self.lock: - # Hold the lock to ensure this section does not race with - # Pool.reset(). - if self.stale_generation(conn.generation, conn.service_id): - close_conn = True - else: - conn.update_last_checkin_time() - conn.update_is_writable(bool(self.is_writable)) + if self.stale_generation(conn.generation, conn.service_id): + close_conn = True + else: + conn.update_last_checkin_time() + conn.update_is_writable(bool(self.is_writable)) + with self._conns_lock: self.conns.appendleft(conn) + with self._max_connecting_cond: # Notify any threads waiting to create a connection. self._max_connecting_cond.notify() if close_conn: conn.close_conn(ConnectionClosedReason.STALE) - with self.size_cond: + with self._active_contexts_lock: if txn: self.ntxns -= 1 elif cursor: self.ncursors -= 1 - self.requests -= 1 self.active_sockets -= 1 + with self._operation_count_lock: self.operation_count -= 1 + with self.size_cond: + self.requests -= 1 self.size_cond.notify() def _perished(self, conn: Connection) -> bool: From cce3e48c1d7cdbdb0e251f0e5592a87c66c962d2 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Fri, 15 May 2026 09:35:06 -0400 Subject: [PATCH 02/11] add self to contributors --- doc/contributors.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/contributors.rst b/doc/contributors.rst index 0bd815ce3f..d36731159e 100644 --- a/doc/contributors.rst +++ b/doc/contributors.rst @@ -108,3 +108,4 @@ The following is a list of people who have contributed to - Steven Silvester (blink1073) - Noah Stapp (NoahStapp) - Cal Jacobson (cj81499) +- Sophia Yang (sophiayangDB) From 8ed0d3577dde7d4be191f2df63539ff1b8c208b7 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Fri, 15 May 2026 11:01:59 -0400 Subject: [PATCH 03/11] fix unprotected notify --- pymongo/asynchronous/pool.py | 8 +++++--- pymongo/synchronous/pool.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 7888422d28..0c2d1cd723 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -796,7 +796,7 @@ def __init__( self.ncursors = 0 self.ntxns = 0 # This lock protects self.active_contexts, self.active_sockets, - # self.__pinned_sockets, self.ncursors, and self.ntxns. + # self.__pinned_sockets, self.ncursors, self.next_connection_id, and self.ntxns. self._active_contexts_lock = _async_create_lock() async def ready(self) -> None: @@ -867,8 +867,10 @@ async def _reset( async with self.lock: self.state = PoolState.CLOSED # Clear the wait queue - self._max_connecting_cond.notify_all() - self.size_cond.notify_all() + async with self._max_connecting_cond: + self._max_connecting_cond.notify_all() + async with self.size_cond: + self.size_cond.notify_all() if interrupt_connections: for context in self.active_contexts.copy(): diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 5985fdb318..9d78f59941 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -794,7 +794,7 @@ def __init__( self.ncursors = 0 self.ntxns = 0 # This lock protects self.active_contexts, self.active_sockets, - # self.__pinned_sockets, self.ncursors, and self.ntxns. + # self.__pinned_sockets, self.ncursors, self.next_connection_id, and self.ntxns. self._active_contexts_lock = _create_lock() def ready(self) -> None: @@ -865,8 +865,10 @@ def _reset( with self.lock: self.state = PoolState.CLOSED # Clear the wait queue - self._max_connecting_cond.notify_all() - self.size_cond.notify_all() + with self._max_connecting_cond: + self._max_connecting_cond.notify_all() + with self.size_cond: + self.size_cond.notify_all() if interrupt_connections: for context in self.active_contexts.copy(): From b29d615f243470541d86b4333ca4ffebb2fd4ea9 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Fri, 15 May 2026 11:24:29 -0400 Subject: [PATCH 04/11] indentation fix --- pymongo/asynchronous/pool.py | 22 +++++++++++----------- pymongo/synchronous/pool.py | 22 +++++++++++----------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 0c2d1cd723..b93230982c 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -863,18 +863,18 @@ async def _reset( async with self._conns_lock: self.conns = keep - if close: - async with self.lock: - self.state = PoolState.CLOSED - # Clear the wait queue - async with self._max_connecting_cond: - self._max_connecting_cond.notify_all() - async with self.size_cond: - self.size_cond.notify_all() + if close: + async with self.lock: + self.state = PoolState.CLOSED + # Clear the wait queue + async with self._max_connecting_cond: + self._max_connecting_cond.notify_all() + async with self.size_cond: + self.size_cond.notify_all() - if interrupt_connections: - for context in self.active_contexts.copy(): - context.cancel() + if interrupt_connections: + for context in self.active_contexts.copy(): + context.cancel() listeners = self.opts._event_listeners # CMAP spec says that close() MUST close sockets before publishing the diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 9d78f59941..89730c21c6 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -861,18 +861,18 @@ def _reset( with self._conns_lock: self.conns = keep - if close: - with self.lock: - self.state = PoolState.CLOSED - # Clear the wait queue - with self._max_connecting_cond: - self._max_connecting_cond.notify_all() - with self.size_cond: - self.size_cond.notify_all() + if close: + with self.lock: + self.state = PoolState.CLOSED + # Clear the wait queue + with self._max_connecting_cond: + self._max_connecting_cond.notify_all() + with self.size_cond: + self.size_cond.notify_all() - if interrupt_connections: - for context in self.active_contexts.copy(): - context.cancel() + if interrupt_connections: + for context in self.active_contexts.copy(): + context.cancel() listeners = self.opts._event_listeners # CMAP spec says that close() MUST close sockets before publishing the From 598a6ef1b8549c0f1ed62fa0c0d0331dc776a254 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Fri, 15 May 2026 14:43:30 -0400 Subject: [PATCH 05/11] remove needs to occur on self.conns --- pymongo/asynchronous/pool.py | 9 ++++++--- pymongo/synchronous/pool.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index b93230982c..2c543cbe20 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -964,9 +964,12 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: if self.opts.max_idle_time_seconds is not None: close_conns = [] - conns = self.conns.copy() - while conns and conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds: - close_conns.append(conns.pop()) + async with self.lock: + while ( + self.conns + and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds + ): + close_conns.append(self.conns.pop()) if not _IS_SYNC: await asyncio.gather( *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value] diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 89730c21c6..cb469e4270 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -960,9 +960,12 @@ def remove_stale_sockets(self, reference_generation: int) -> None: if self.opts.max_idle_time_seconds is not None: close_conns = [] - conns = self.conns.copy() - while conns and conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds: - close_conns.append(conns.pop()) + with self.lock: + while ( + self.conns + and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds + ): + close_conns.append(self.conns.pop()) if not _IS_SYNC: asyncio.gather( *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value] From 36738bdd1ba49ff83a7c72ad2a55fdef6bc40a11 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Fri, 15 May 2026 15:37:24 -0400 Subject: [PATCH 06/11] changelog --- doc/changelog.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index ed6f23f86d..b3fd7e2f7a 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,13 @@ Changelog ========= +Changes in Version 4.18.0 (2026/XX/XX) +-------------------------------------- +PyMongo 4.18 brings a number of changes including: + +- Improved connection pool throughput under concurrent load by replacing the + single pool lock with fine-grained locks to reduce lock contention. + Changes in Version 4.17.0 (2026/04/20) -------------------------------------- From a2ce38ab367360b5508a277bcb5436e0bd7ad39d Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Mon, 18 May 2026 15:54:48 -0400 Subject: [PATCH 07/11] address pr comments --- pymongo/asynchronous/pool.py | 59 +++++++++++++++++++----------------- pymongo/synchronous/pool.py | 59 +++++++++++++++++++----------------- 2 files changed, 62 insertions(+), 56 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 2c543cbe20..111e69682b 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -802,8 +802,8 @@ def __init__( async def ready(self) -> None: # Take the lock to avoid the race condition described in PYTHON-2699. state_changed = False - if self.state != PoolState.READY: - async with self.lock: + async with self.lock: + if self.state != PoolState.READY: self.state = PoolState.READY state_changed = True if not state_changed: @@ -832,13 +832,12 @@ async def _reset( interrupt_connections: bool = False, ) -> None: old_state = self.state - if self.closed: - return is_fork = False async with self.lock: + if self.closed: + return if self.opts.pause_enabled and pause and not self.opts.load_balanced: old_state, self.state = self.state, PoolState.PAUSED - self.gen.inc(service_id) newpid = os.getpid() if self.pid != newpid: self.pid = newpid @@ -848,19 +847,21 @@ async def _reset( self.active_sockets = 0 async with self._operation_count_lock: self.operation_count = 0 - async with self._conns_lock: - if service_id is None: + if service_id is None: + async with self._conns_lock: + self.gen.inc(service_id) sockets, self.conns = self.conns, collections.deque() if service_id is not None: discard: collections.deque = collections.deque() # type: ignore[type-arg] keep: collections.deque = collections.deque() # type: ignore[type-arg] - for conn in self.conns: - if conn.service_id == service_id: - discard.append(conn) - else: - keep.append(conn) - sockets = discard async with self._conns_lock: + self.gen.inc(service_id) + for conn in self.conns: + if conn.service_id == service_id: + discard.append(conn) + else: + keep.append(conn) + sockets = discard self.conns = keep if close: @@ -932,8 +933,9 @@ async def update_is_writable(self, is_writable: Optional[bool]) -> None: Pool. """ self.is_writable = is_writable - for _socket in self.conns: - _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] + async with self._conns_lock: + for _socket in self.conns: + _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] async def reset( self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False @@ -964,7 +966,7 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: if self.opts.max_idle_time_seconds is not None: close_conns = [] - async with self.lock: + async with self._conns_lock: while ( self.conns and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds @@ -1000,13 +1002,13 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: close_conn = False # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. - if self.gen.get_overall() != reference_generation: - close_conn = True - if not close_conn: - async with self._conns_lock: + async with self._conns_lock: + if self.gen.get_overall() != reference_generation: + close_conn = True + else: self.conns.appendleft(conn) - async with self._active_contexts_lock: - self.active_contexts.discard(conn.cancel_context) + async with self._active_contexts_lock: + self.active_contexts.discard(conn.cancel_context) if close_conn: await conn.close_conn(ConnectionClosedReason.STALE) return @@ -1410,13 +1412,14 @@ async def checkin(self, conn: AsyncConnection) -> None: ) else: close_conn = False - if self.stale_generation(conn.generation, conn.service_id): - close_conn = True - else: - conn.update_last_checkin_time() - conn.update_is_writable(bool(self.is_writable)) - async with self._conns_lock: + async with self._conns_lock: + if self.stale_generation(conn.generation, conn.service_id): + close_conn = True + else: + conn.update_last_checkin_time() + conn.update_is_writable(bool(self.is_writable)) self.conns.appendleft(conn) + if not close_conn: async with self._max_connecting_cond: # Notify any threads waiting to create a connection. self._max_connecting_cond.notify() diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index cb469e4270..e2594b8ead 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -800,8 +800,8 @@ def __init__( def ready(self) -> None: # Take the lock to avoid the race condition described in PYTHON-2699. state_changed = False - if self.state != PoolState.READY: - with self.lock: + with self.lock: + if self.state != PoolState.READY: self.state = PoolState.READY state_changed = True if not state_changed: @@ -830,13 +830,12 @@ def _reset( interrupt_connections: bool = False, ) -> None: old_state = self.state - if self.closed: - return is_fork = False with self.lock: + if self.closed: + return if self.opts.pause_enabled and pause and not self.opts.load_balanced: old_state, self.state = self.state, PoolState.PAUSED - self.gen.inc(service_id) newpid = os.getpid() if self.pid != newpid: self.pid = newpid @@ -846,19 +845,21 @@ def _reset( self.active_sockets = 0 with self._operation_count_lock: self.operation_count = 0 - with self._conns_lock: - if service_id is None: + if service_id is None: + with self._conns_lock: + self.gen.inc(service_id) sockets, self.conns = self.conns, collections.deque() if service_id is not None: discard: collections.deque = collections.deque() # type: ignore[type-arg] keep: collections.deque = collections.deque() # type: ignore[type-arg] - for conn in self.conns: - if conn.service_id == service_id: - discard.append(conn) - else: - keep.append(conn) - sockets = discard with self._conns_lock: + self.gen.inc(service_id) + for conn in self.conns: + if conn.service_id == service_id: + discard.append(conn) + else: + keep.append(conn) + sockets = discard self.conns = keep if close: @@ -930,8 +931,9 @@ def update_is_writable(self, is_writable: Optional[bool]) -> None: Pool. """ self.is_writable = is_writable - for _socket in self.conns: - _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] + with self._conns_lock: + for _socket in self.conns: + _socket.update_is_writable(self.is_writable) # type: ignore[arg-type] def reset( self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False @@ -960,7 +962,7 @@ def remove_stale_sockets(self, reference_generation: int) -> None: if self.opts.max_idle_time_seconds is not None: close_conns = [] - with self.lock: + with self._conns_lock: while ( self.conns and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds @@ -996,13 +998,13 @@ def remove_stale_sockets(self, reference_generation: int) -> None: close_conn = False # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. - if self.gen.get_overall() != reference_generation: - close_conn = True - if not close_conn: - with self._conns_lock: + with self._conns_lock: + if self.gen.get_overall() != reference_generation: + close_conn = True + else: self.conns.appendleft(conn) - with self._active_contexts_lock: - self.active_contexts.discard(conn.cancel_context) + with self._active_contexts_lock: + self.active_contexts.discard(conn.cancel_context) if close_conn: conn.close_conn(ConnectionClosedReason.STALE) return @@ -1406,13 +1408,14 @@ def checkin(self, conn: Connection) -> None: ) else: close_conn = False - if self.stale_generation(conn.generation, conn.service_id): - close_conn = True - else: - conn.update_last_checkin_time() - conn.update_is_writable(bool(self.is_writable)) - with self._conns_lock: + with self._conns_lock: + if self.stale_generation(conn.generation, conn.service_id): + close_conn = True + else: + conn.update_last_checkin_time() + conn.update_is_writable(bool(self.is_writable)) self.conns.appendleft(conn) + if not close_conn: with self._max_connecting_cond: # Notify any threads waiting to create a connection. self._max_connecting_cond.notify() From c73a21f62067b3e3b26cba6a5d6e3c48719f8981 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Tue, 19 May 2026 15:48:49 -0400 Subject: [PATCH 08/11] rename lock --- pymongo/asynchronous/pool.py | 16 ++++++++-------- pymongo/synchronous/pool.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 111e69682b..747e3e0f4e 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -709,10 +709,10 @@ def __init__( :param options: a PoolOptions instance :param is_sdam: whether to call hello for each new AsyncConnection """ - # Main lock only used to protect updating attributes. + # Lock only used to protect updating attributes like self.state. # Avoid any additional work while holding the lock. # If looping over an attribute, copy the container and do not take the lock. - self.lock = _async_create_lock() + self._state_lock = _async_create_lock() if options.pause_enabled: self.state = PoolState.PAUSED @@ -728,6 +728,8 @@ def __init__( # This lock should only be contended by threads adding/removing connections. self._conns_lock = _async_create_lock() + self.active_contexts: set[_CancellationContext] = set() + self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 # Track whether the sockets in this pool are writeable or not. @@ -787,8 +789,6 @@ def __init__( # This lock should be contended on every operation. self._operation_count_lock = _async_create_lock() - self.active_contexts: set[_CancellationContext] = set() - self.active_sockets = 0 # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). @@ -802,7 +802,7 @@ def __init__( async def ready(self) -> None: # Take the lock to avoid the race condition described in PYTHON-2699. state_changed = False - async with self.lock: + async with self._state_lock: if self.state != PoolState.READY: self.state = PoolState.READY state_changed = True @@ -833,7 +833,7 @@ async def _reset( ) -> None: old_state = self.state is_fork = False - async with self.lock: + async with self._state_lock: if self.closed: return if self.opts.pause_enabled and pause and not self.opts.load_balanced: @@ -865,7 +865,7 @@ async def _reset( self.conns = keep if close: - async with self.lock: + async with self._state_lock: self.state = PoolState.CLOSED # Clear the wait queue async with self._max_connecting_cond: @@ -960,7 +960,7 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: pool. """ # Take the lock to avoid the race condition described in PYTHON-2699. - async with self.lock: + async with self._state_lock: if self.state != PoolState.READY: return diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index e2594b8ead..c082ac97e9 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -707,10 +707,10 @@ def __init__( :param options: a PoolOptions instance :param is_sdam: whether to call hello for each new Connection """ - # Main lock only used to protect updating attributes. + # Lock only used to protect updating attributes like self.state. # Avoid any additional work while holding the lock. # If looping over an attribute, copy the container and do not take the lock. - self.lock = _create_lock() + self._state_lock = _create_lock() if options.pause_enabled: self.state = PoolState.PAUSED @@ -726,6 +726,8 @@ def __init__( # This lock should only be contended by threads adding/removing connections. self._conns_lock = _create_lock() + self.active_contexts: set[_CancellationContext] = set() + self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 # Track whether the sockets in this pool are writeable or not. @@ -785,8 +787,6 @@ def __init__( # This lock should be contended on every operation. self._operation_count_lock = _create_lock() - self.active_contexts: set[_CancellationContext] = set() - self.active_sockets = 0 # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). @@ -800,7 +800,7 @@ def __init__( def ready(self) -> None: # Take the lock to avoid the race condition described in PYTHON-2699. state_changed = False - with self.lock: + with self._state_lock: if self.state != PoolState.READY: self.state = PoolState.READY state_changed = True @@ -831,7 +831,7 @@ def _reset( ) -> None: old_state = self.state is_fork = False - with self.lock: + with self._state_lock: if self.closed: return if self.opts.pause_enabled and pause and not self.opts.load_balanced: @@ -863,7 +863,7 @@ def _reset( self.conns = keep if close: - with self.lock: + with self._state_lock: self.state = PoolState.CLOSED # Clear the wait queue with self._max_connecting_cond: @@ -956,7 +956,7 @@ def remove_stale_sockets(self, reference_generation: int) -> None: pool. """ # Take the lock to avoid the race condition described in PYTHON-2699. - with self.lock: + with self._state_lock: if self.state != PoolState.READY: return From 394dfa7cf948239d27aeab797349f70aade4cacc Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Tue, 19 May 2026 22:05:46 -0400 Subject: [PATCH 09/11] use new lock in tests --- test/asynchronous/unified_format.py | 2 +- test/unified_format.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 1fb93e7b86..ef88118afb 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -336,7 +336,7 @@ async def _create_entity(self, entity_spec, uri=None): while True: if (time.monotonic() - t0) > spec["awaitMinPoolSizeMS"] * 1000: raise ValueError("Test timed out during awaitMinPoolSize") - async with pool.lock: + async with pool._conns_lock: if len(pool.conns) + pool.active_sockets >= pool.opts.min_pool_size: break await asyncio.sleep(0.1) diff --git a/test/unified_format.py b/test/unified_format.py index 5516a7adf1..ec24959373 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -335,7 +335,7 @@ def _create_entity(self, entity_spec, uri=None): while True: if (time.monotonic() - t0) > spec["awaitMinPoolSizeMS"] * 1000: raise ValueError("Test timed out during awaitMinPoolSize") - with pool.lock: + with pool._conns_lock: if len(pool.conns) + pool.active_sockets >= pool.opts.min_pool_size: break time.sleep(0.1) From 1169c7323dcc590f44a9558da041d3b8a5aad2b7 Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Tue, 19 May 2026 22:50:29 -0400 Subject: [PATCH 10/11] update load balancer tests --- test/asynchronous/test_load_balancer.py | 2 +- test/test_load_balancer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py index 8c5d57434c..9a0919ea7f 100644 --- a/test/asynchronous/test_load_balancer.py +++ b/test/asynchronous/test_load_balancer.py @@ -168,7 +168,7 @@ def __init__(self, pool): self.unlock = create_async_event() async def lock_pool(self): - async with self.pool.lock: + async with self.pool._conns_lock: self.locked.set() # Wait for the unlock flag. unlock_pool = await self.wait(self.unlock, 10) diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index de4b14e546..fe18bc505b 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -168,7 +168,7 @@ def __init__(self, pool): self.unlock = create_event() def lock_pool(self): - with self.pool.lock: + with self.pool._conns_lock: self.locked.set() # Wait for the unlock flag. unlock_pool = self.wait(self.unlock, 10) From 1cb560999476613ef3f86c20264878e9dbe1d19b Mon Sep 17 00:00:00 2001 From: Sophia Yang Date: Wed, 20 May 2026 16:50:55 -0400 Subject: [PATCH 11/11] small read benchmark --- test/performance/async_perf_test.py | 56 ++++++++++++++++++++++- test/performance/perf_test.py | 71 +++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 1 deletion(-) diff --git a/test/performance/async_perf_test.py b/test/performance/async_perf_test.py index 6eb31ea4fe..961a2ddac2 100644 --- a/test/performance/async_perf_test.py +++ b/test/performance/async_perf_test.py @@ -56,11 +56,12 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncPyMongoTestCase, async_client_context, unittest +from test.asynchronous import AsyncPyMongoTestCase, async_client_context, get_loop, unittest from bson import encode from gridfs import AsyncGridFSBucket from pymongo import ( + AsyncMongoClient, DeleteOne, InsertOne, ReplaceOne, @@ -98,6 +99,9 @@ def tearDownModule(): else: print(output) + if getattr(async_client_context, "client", None): + get_loop().run_until_complete(async_client_context.client.close()) + class Timer: def __enter__(self): @@ -297,6 +301,56 @@ async def do_task(self): await asyncio.gather(*[find_one({"_id": _id}) for _id in self.inserted_ids]) +class SmallReadTest(PerformanceTest): + dataset = "small_doc.json" + n_tasks = 1 + + async def asyncSetUp(self): + await super().asyncSetUp() + with open( # noqa: ASYNC101 + os.path.join(TEST_PATH, os.path.join("single_and_multi_document", self.dataset)) + ) as data: + self.document = json.loads(data.read()) + + self.data_size = len(encode(self.document)) * NUM_DOCS + + client_options = dict(async_client_context.client_options) + client_options["minPoolSize"] = self.n_tasks + self.client = AsyncMongoClient(**client_options) + + await self.client.drop_database("perftest") + self.corpus = self.client.perftest.corpus + result = await self.corpus.insert_many([self.document.copy() for _ in range(NUM_DOCS)]) + self.inserted_ids = result.inserted_ids + + await self.client.admin.command("ping") + await self.corpus.find_one({"_id": self.inserted_ids[0]}) + + async def asyncTearDown(self): + try: + await super().asyncTearDown() + await self.client.drop_database("perftest") + finally: + await self.client.close() + + async def before(self): + pass + + async def after(self): + pass + + +class TestSmallReadFindOneByID(SmallReadTest, AsyncPyMongoTestCase): + async def do_task(self): + find_one = self.corpus.find_one + for _id in self.inserted_ids: + await find_one({"_id": _id}) + + +class TestSmallReadFindOneByID8Tasks(TestSmallReadFindOneByID): + n_tasks = 8 + + class SmallDocInsertTest(TestDocument): dataset = "small_doc.json" diff --git a/test/performance/perf_test.py b/test/performance/perf_test.py index 5688d28d2d..9d4c2400d5 100644 --- a/test/performance/perf_test.py +++ b/test/performance/perf_test.py @@ -100,6 +100,9 @@ def tearDownModule(): else: print(output) + if getattr(client_context, "client", None): + client_context.client.close() + class Timer: def __enter__(self): @@ -394,6 +397,74 @@ class TestFindOneByID8Threads(TestFindOneByID): n_threads = 8 +class SmallReadTest(PerformanceTest): + dataset = "small_doc.json" + n_threads = 1 + + def setUp(self): + super().setUp() + with open( + os.path.join(TEST_PATH, os.path.join("single_and_multi_document", self.dataset)) + ) as data: + self.document = json.loads(data.read()) + + self.data_size = len(encode(self.document)) * NUM_DOCS + + client_options = dict(client_context.client_options) + client_options["minPoolSize"] = self.n_threads + self.client = MongoClient(**client_options) + + self.client.drop_database("perftest") + self.corpus = self.client.perftest.corpus + result = self.corpus.insert_many([self.document.copy() for _ in range(NUM_DOCS)]) + self.inserted_ids = result.inserted_ids + + self.client.admin.command("ping") + self.corpus.find_one({"_id": self.inserted_ids[0]}) + + def tearDown(self): + try: + super().tearDown() + self.client.drop_database("perftest") + finally: + self.client.close() + + def before(self): + pass + + def after(self): + pass + + +class TestSmallReadFindOneByID(SmallReadTest, unittest.TestCase): + def do_task(self): + find_one = self.corpus.find_one + for _id in self.inserted_ids: + find_one({"_id": _id}) + + +class TestSmallReadFindOneByID8Threads(TestSmallReadFindOneByID): + n_threads = 8 + + +class TestSmallReadFindOneByID8ThreadsPerClient(SmallReadTest, unittest.TestCase): + n_threads = 8 + + def do_task(self): + client_options = dict(client_context.client_options) + client = MongoClient(**client_options) + try: + corpus = client.perftest.corpus + client.admin.command("ping") + corpus.find_one({"_id": self.inserted_ids[0]}) + + find_one = corpus.find_one + for _id in self.inserted_ids: + find_one({"_id": _id}) + finally: + client.close() + + class SmallDocInsertTest(TestDocument): dataset = "small_doc.json"