diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 32ed5507..7854b1f2 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -51,7 +51,13 @@ Connection::Connection(const std::wstring& conn_str, bool use_pool) } Connection::~Connection() { - disconnect(); // fallback if user forgets to disconnect + try { + disconnect(); // fallback if user forgets to disconnect + } catch (...) { + // Never throw from a destructor — doing so during stack unwinding + // causes std::terminate(). Log and swallow. + LOG_ERROR("Exception suppressed in ~Connection destructor"); + } } // Allocates connection handle @@ -99,23 +105,22 @@ void Connection::disconnect() { // When we free the DBC handle below, the ODBC driver will automatically free // all child STMT handles. We need to tell the SqlHandle objects about this // so they don't try to free the handles again during their destruction. - + // THREAD-SAFETY: Lock mutex to safely access _childStatementHandles // This protects against concurrent allocStatementHandle() calls or GC finalizers { std::lock_guard lock(_childHandlesMutex); - + // First compact: remove expired weak_ptrs (they're already destroyed) size_t originalSize = _childStatementHandles.size(); _childStatementHandles.erase( std::remove_if(_childStatementHandles.begin(), _childStatementHandles.end(), [](const std::weak_ptr& wp) { return wp.expired(); }), _childStatementHandles.end()); - - LOG("Compacted child handles: %zu -> %zu (removed %zu expired)", - originalSize, _childStatementHandles.size(), - originalSize - _childStatementHandles.size()); - + + LOG("Compacted child handles: %zu -> %zu (removed %zu expired)", originalSize, + _childStatementHandles.size(), originalSize - _childStatementHandles.size()); + LOG("Marking %zu child statement handles as implicitly freed", _childStatementHandles.size()); for (auto& weakHandle : _childStatementHandles) { @@ 
-124,8 +129,10 @@ void Connection::disconnect() { // This is guaranteed by allocStatementHandle() which only creates STMT handles // If this assertion fails, it indicates a serious bug in handle tracking if (handle->type() != SQL_HANDLE_STMT) { - LOG_ERROR("CRITICAL: Non-STMT handle (type=%d) found in _childStatementHandles. " - "This will cause a handle leak!", handle->type()); + LOG_ERROR( + "CRITICAL: Non-STMT handle (type=%d) found in _childStatementHandles. " + "This will cause a handle leak!", + handle->type()); continue; // Skip marking to prevent leak } handle->markImplicitlyFreed(); @@ -136,8 +143,24 @@ void Connection::disconnect() { } // Release lock before potentially slow SQLDisconnect call SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get()); - checkError(ret); - // triggers SQLFreeHandle via destructor, if last owner + if (!SQL_SUCCEEDED(ret)) { + // Log the error but do NOT throw — disconnect must be safe to call + // from destructors, reset() failure paths, and pool cleanup. + // Throwing here during stack unwinding causes std::terminate(). + LOG_ERROR("SQLDisconnect failed (ret=%d), forcing handle cleanup", ret); + + // Best-effort: retrieve and log ODBC diagnostics for debuggability. + // This must not throw, to keep disconnect noexcept-safe. + try { + ErrorInfo err = SQLCheckError_Wrap(SQL_HANDLE_DBC, _dbcHandle, ret); + std::string diagMsg = WideToUTF8(err.ddbcErrorMsg); + LOG_ERROR("SQLDisconnect diagnostics: %s", diagMsg.c_str()); + } catch (...) { + // Swallow all exceptions: cleanup paths must not throw. 
+ LOG_ERROR("SQLDisconnect: failed to retrieve ODBC diagnostics"); + } + } + // Always free the handle regardless of SQLDisconnect result _dbcHandle.reset(); } else { LOG("No connection handle to disconnect"); @@ -221,7 +244,7 @@ SqlHandlePtr Connection::allocStatementHandle() { // or GC finalizers running from different threads { std::lock_guard lock(_childHandlesMutex); - + // Track this child handle so we can mark it as implicitly freed when connection closes // Use weak_ptr to avoid circular references and allow normal cleanup _childStatementHandles.push_back(stmtHandle); @@ -237,9 +260,8 @@ SqlHandlePtr Connection::allocStatementHandle() { [](const std::weak_ptr& wp) { return wp.expired(); }), _childStatementHandles.end()); _allocationsSinceCompaction = 0; - LOG("Periodic compaction: %zu -> %zu handles (removed %zu expired)", - originalSize, _childStatementHandles.size(), - originalSize - _childStatementHandles.size()); + LOG("Periodic compaction: %zu -> %zu handles (removed %zu expired)", originalSize, + _childStatementHandles.size(), originalSize - _childStatementHandles.size()); } } // Release lock diff --git a/mssql_python/pybind/connection/connection_pool.cpp b/mssql_python/pybind/connection/connection_pool.cpp index 3000a970..3cb30fc7 100644 --- a/mssql_python/pybind/connection/connection_pool.cpp +++ b/mssql_python/pybind/connection/connection_pool.cpp @@ -21,20 +21,24 @@ std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, auto now = std::chrono::steady_clock::now(); size_t before = _pool.size(); + LOG("ConnectionPool::acquire: pool_size=%zu, max_size=%zu, idle_timeout=%d", before, + _max_size, _idle_timeout_secs); + // Phase 1: Remove stale connections, collect for later disconnect - _pool.erase(std::remove_if(_pool.begin(), _pool.end(), - [&](const std::shared_ptr& conn) { - auto idle_time = - std::chrono::duration_cast( - now - conn->lastUsed()) - .count(); - if (idle_time > _idle_timeout_secs) { - to_disconnect.push_back(conn); - 
return true; - } - return false; - }), - _pool.end()); + _pool.erase( + std::remove_if( + _pool.begin(), _pool.end(), + [&](const std::shared_ptr& conn) { + auto idle_time = + std::chrono::duration_cast(now - conn->lastUsed()) + .count(); + if (idle_time > _idle_timeout_secs) { + to_disconnect.push_back(conn); + return true; + } + return false; + }), + _pool.end()); size_t pruned = before - _pool.size(); _current_size = (_current_size >= pruned) ? (_current_size - pruned) : 0; diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 51fb4965..31f2eca5 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4514,10 +4514,8 @@ SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) { return rowCount; } -static std::once_flag pooling_init_flag; void enable_pooling(int maxSize, int idleTimeout) { - std::call_once(pooling_init_flag, - [&]() { ConnectionPoolManager::getInstance().configure(maxSize, idleTimeout); }); + ConnectionPoolManager::getInstance().configure(maxSize, idleTimeout); } // Thread-safe decimal separator setting diff --git a/tests/test_009_pooling.py b/tests/test_009_pooling.py index 1a3e5f09..9647cd2b 100644 --- a/tests/test_009_pooling.py +++ b/tests/test_009_pooling.py @@ -104,20 +104,16 @@ def test_connection_pooling_isolation_level_reset(conn_str): # Set isolation level to SERIALIZABLE (non-default) conn1.set_attr(mssql_python.SQL_ATTR_TXN_ISOLATION, mssql_python.SQL_TXN_SERIALIZABLE) - # Verify the isolation level was set + # Verify the isolation level was set (use DBCC USEROPTIONS to avoid + # requiring VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_sessions) cursor1 = conn1.cursor() - cursor1.execute( - "SELECT CASE transaction_isolation_level " - "WHEN 0 THEN 'Unspecified' " - "WHEN 1 THEN 'ReadUncommitted' " - "WHEN 2 THEN 'ReadCommitted' " - "WHEN 3 THEN 'RepeatableRead' " - "WHEN 4 THEN 'Serializable' " - "WHEN 5 THEN 'Snapshot' END AS isolation_level " 
- "FROM sys.dm_exec_sessions WHERE session_id = @@SPID" - ) - isolation_level_1 = cursor1.fetchone()[0] - assert isolation_level_1 == "Serializable", f"Expected Serializable, got {isolation_level_1}" + cursor1.execute("DBCC USEROPTIONS WITH NO_INFOMSGS") + isolation_level_1 = None + for row in cursor1.fetchall(): + if row[0] == "isolation level": + isolation_level_1 = row[1] + break + assert isolation_level_1 == "serializable", f"Expected serializable, got {isolation_level_1}" # Get SPID for verification of connection reuse cursor1.execute("SELECT @@SPID") @@ -138,24 +134,20 @@ def test_connection_pooling_isolation_level_reset(conn_str): # Verify connection was reused assert spid1 == spid2, "Connection was not reused from pool" - # Check if isolation level is reset to default - cursor2.execute( - "SELECT CASE transaction_isolation_level " - "WHEN 0 THEN 'Unspecified' " - "WHEN 1 THEN 'ReadUncommitted' " - "WHEN 2 THEN 'ReadCommitted' " - "WHEN 3 THEN 'RepeatableRead' " - "WHEN 4 THEN 'Serializable' " - "WHEN 5 THEN 'Snapshot' END AS isolation_level " - "FROM sys.dm_exec_sessions WHERE session_id = @@SPID" - ) - isolation_level_2 = cursor2.fetchone()[0] + # Check if isolation level is reset to default (use DBCC USEROPTIONS to avoid + # requiring VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_sessions) + cursor2.execute("DBCC USEROPTIONS WITH NO_INFOMSGS") + isolation_level_2 = None + for row in cursor2.fetchall(): + if row[0] == "isolation level": + isolation_level_2 = row[1] + break # Verify isolation level is reset to default (READ COMMITTED) # This is the CORRECT behavior for connection pooling - we should reset # session state to prevent settings from one usage affecting the next - assert isolation_level_2 == "ReadCommitted", ( - f"Isolation level was not reset! Expected 'ReadCommitted', got '{isolation_level_2}'. " + assert isolation_level_2 == "read committed", ( + f"Isolation level was not reset! Expected 'read committed', got '{isolation_level_2}'. 
" f"This indicates session state leaked from the previous connection usage." ) @@ -278,30 +270,28 @@ def try_overflow(): c.close() -@pytest.mark.skip("Flaky test - idle timeout behavior needs investigation") def test_pool_idle_timeout_removes_connections(conn_str): """Test that idle_timeout removes connections from the pool after the timeout.""" pooling(max_size=2, idle_timeout=1) conn1 = connect(conn_str) - spid_list = [] cursor1 = conn1.cursor() + # Use @@SPID to identify the connection without requiring + # VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_connections. cursor1.execute("SELECT @@SPID") spid1 = cursor1.fetchone()[0] - spid_list.append(spid1) conn1.close() - # Wait for longer than idle_timeout - time.sleep(3) + # Wait well beyond the idle_timeout to account for slow CI and integer-second granularity + time.sleep(5) - # Get a new connection, which should not reuse the previous SPID + # Get a new connection — the idle one should have been evicted during acquire() conn2 = connect(conn_str) cursor2 = conn2.cursor() cursor2.execute("SELECT @@SPID") spid2 = cursor2.fetchone()[0] - spid_list.append(spid2) conn2.close() - assert spid1 != spid2, "Idle timeout did not remove connection from pool" + assert spid1 != spid2, "Idle timeout did not remove connection from pool — same SPID reused" # ============================================================================= @@ -309,51 +299,63 @@ def test_pool_idle_timeout_removes_connections(conn_str): # ============================================================================= -@pytest.mark.skip( - "Test causes fatal crash - forcibly closing underlying connection leads to undefined behavior" -) def test_pool_removes_invalid_connections(conn_str): - """Test that the pool removes connections that become invalid (simulate by closing underlying connection).""" + """Test that the pool removes connections that become invalid and recovers gracefully. 
+ + This test simulates a connection being returned to the pool in a dirty state + (with an open transaction) by calling _conn.close() directly, bypassing the + normal Python close() which does a rollback. The pool's acquire() should detect + the bad connection during reset(), discard it, and create a fresh one. + """ pooling(max_size=1, idle_timeout=30) conn = connect(conn_str) cursor = conn.cursor() cursor.execute("SELECT 1") - # Simulate invalidation by forcibly closing the connection at the driver level - try: - # Try to access a private attribute or method to forcibly close the underlying connection - # This is implementation-specific; if not possible, skip - if hasattr(conn, "_conn") and hasattr(conn._conn, "close"): - conn._conn.close() - else: - pytest.skip("Cannot forcibly close underlying connection for this driver") - except Exception: - pass - # Safely close the connection, ignoring errors due to forced invalidation + cursor.fetchone() + + # Record the SPID of the original connection (avoids requiring + # VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_connections) + cursor.execute("SELECT @@SPID") + original_spid = cursor.fetchone()[0] + + # Force-return the connection to the pool WITHOUT rollback. + # This leaves the pooled connection in a dirty state (open implicit transaction) + # which will cause reset() to fail on next acquire(). 
+ conn._conn.close() + + # Python close() will fail since the underlying handle is already gone try: conn.close() - except RuntimeError as e: - if "not initialized" not in str(e): - raise - # Now, get a new connection from the pool and ensure it works + except RuntimeError: + pass + + # Now get a new connection — the pool should discard the dirty one and create fresh new_conn = connect(conn_str) new_cursor = new_conn.cursor() - try: - new_cursor.execute("SELECT 1") - result = new_cursor.fetchone() - assert result is not None and result[0] == 1, "Pool did not remove invalid connection" - finally: - new_conn.close() + new_cursor.execute("SELECT 1") + result = new_cursor.fetchone() + assert result is not None and result[0] == 1, "Pool did not recover from invalid connection" + + # Verify it's a different physical connection + new_cursor.execute("SELECT @@SPID") + new_spid = new_cursor.fetchone()[0] + assert ( + original_spid != new_spid + ), "Expected a new physical connection after pool discarded the dirty one" + + new_conn.close() def test_pool_recovery_after_failed_connection(conn_str): """Test that the pool recovers after a failed connection attempt.""" pooling(max_size=1, idle_timeout=30) # First, try to connect with a bad password (should fail) - if "Pwd=" in conn_str: - bad_conn_str = conn_str.replace("Pwd=", "Pwd=wrongpassword") - elif "Password=" in conn_str: - bad_conn_str = conn_str.replace("Password=", "Password=wrongpassword") - else: + import re + + # Replace the value of the first Pwd/Password key-value pair with "wrongpassword" + pattern = re.compile(r"(?i)((?:Pwd|Password)\s*=\s*)([^;]*)") + bad_conn_str, num_subs = pattern.subn(lambda m: m.group(1) + "wrongpassword", conn_str, count=1) + if num_subs == 0: pytest.skip("No password found in connection string to modify") with pytest.raises(Exception): connect(bad_conn_str)