From f854be5c5299bf47f0ac975f2ba37764295945c1 Mon Sep 17 00:00:00 2001 From: Andrew de Waal Date: Mon, 15 Jun 2026 08:48:47 -0700 Subject: [PATCH] refactor: support closing all connections to a db in Rust Previously, the only way to close the connection to a database was through a command invocation, which is only available through IPC. We need the ability to close any open connections to a db in Rust. For example if we are performing file operations (such as deleting and recreating the db) from the Rust layer, we need to first ensure all connections are closed before taking any further action. This simple refactor breaks the logic to close any open connections into a public function callable from Rust. --- CHANGELOG.md | 14 + crates/sqlx-sqlite-toolkit/src/lib.rs | 2 +- .../sqlx-sqlite-toolkit/src/transactions.rs | 193 +++++- .../tests/transaction_state_tests.rs | 85 ++- src/commands.rs | 92 ++- src/error.rs | 11 + src/lib.rs | 562 +++++++++++++++++- 7 files changed, 840 insertions(+), 119 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 330fe9c..1cae6c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,20 @@ Databases must be registered on the Rust side with a stable key before they can - Parent directory auto-creation during registration validation for file paths. - CI check that committed `api-iife.js` matches a fresh Rollup build. +### Changed + +#### `close` aborts active transactions before closing + +`Database.close()` (IPC `plugin:sqlite|close`), `Database.close_all()` (IPC `plugin:sqlite|close_all`), and the Rust [`Connection::close`](src/lib.rs) API now roll back or cancel in-flight transactions before closing connection pools. + +- **Interruptible transactions** (`beginInterruptibleTransaction`): explicitly rolled back via `ROLLBACK` before the pool is closed. +- **Regular transactions** (`executeTransaction`): the in-flight task is aborted and awaited so pooled connections are released before the pool closes. + +Previously, `close` only aborted active subscriptions; open transactions could block a clean shutdown or leave uncommitted work on pooled connections. The same cleanup logic is available to Rust callers through [`close_database`](src/lib.rs) and [`close_all_loaded_databases`](src/lib.rs). + +Transaction cleanup failures propagate as errors rather than being logged and ignored, so a successful close indicates the database file is safe to delete or recreate. + ### Fixed +- `load()` no longer hangs forever for databases registered without a migrator. Migration state is only tracked when a migrator is provided; otherwise `await_migrations` proceeds immediately. - Regenerated `api-iife.js` so all IPC calls use `dbKey` consistently. diff --git a/crates/sqlx-sqlite-toolkit/src/lib.rs b/crates/sqlx-sqlite-toolkit/src/lib.rs index 201ba2d..627801c 100644 --- a/crates/sqlx-sqlite-toolkit/src/lib.rs +++ b/crates/sqlx-sqlite-toolkit/src/lib.rs @@ -46,7 +46,7 @@ pub use error::{Error, Result}; pub use pagination::{KeysetColumn, KeysetPage, SortDirection}; pub use transactions::{ ActiveInterruptibleTransaction, ActiveInterruptibleTransactions, ActiveRegularTransactions, - Statement, TransactionWriter, cleanup_all_transactions, + Statement, TransactionWriter, cleanup_all_transactions, cleanup_transactions_for_db, }; pub use wrapper::{ DatabaseWrapper, InterruptibleTransaction, InterruptibleTransactionBuilder, diff --git a/crates/sqlx-sqlite-toolkit/src/transactions.rs b/crates/sqlx-sqlite-toolkit/src/transactions.rs index 559d9eb..db78056 100644 --- a/crates/sqlx-sqlite-toolkit/src/transactions.rs +++ b/crates/sqlx-sqlite-toolkit/src/transactions.rs @@ -10,7 +10,7 @@ use serde_json::Value as JsonValue; use sqlx::{Column, Row}; use sqlx_sqlite_conn_mgr::{AttachedWriteGuard, WriteGuard}; use tokio::sync::{Mutex, RwLock}; -use tokio::task::AbortHandle; +use tokio::task::JoinHandle; use tracing::{debug, warn}; #[cfg(feature = "observer")] @@ -106,6 +106,8 @@ pub struct ActiveInterruptibleTransaction { // call sqlx's rt::spawn and panic with "this functionality requires a Tokio // context". runtime_handle: tokio::runtime::Handle, + #[cfg(test)] + force_rollback_failure: bool, } impl ActiveInterruptibleTransaction { @@ -121,9 +123,18 @@ impl ActiveInterruptibleTransaction { writer: Some(writer), created_at: Instant::now(), runtime_handle: tokio::runtime::Handle::current(), + #[cfg(test)] + force_rollback_failure: false, } } + /// Test-only: force `rollback()` to fail when aborting via transaction state. + #[cfg(test)] + pub fn force_rollback_failure_for_test(mut self) -> Self { + self.force_rollback_failure = true; + self + } + fn writer_mut(&mut self) -> Result<&mut TransactionWriter> { self .writer @@ -208,6 +219,15 @@ impl ActiveInterruptibleTransaction { /// Rollback this transaction pub async fn rollback(mut self) -> Result<()> { + #[cfg(test)] + if self.force_rollback_failure { + let db_path = self.db_path.clone(); + drop(self.take_writer()?); + return Err(Error::Other(format!( + "forced rollback failure for test (db: {db_path})" + ))); + } + let mut writer = self.take_writer()?; writer.rollback().await?; @@ -369,7 +389,7 @@ impl ActiveInterruptibleTransactions { } } - pub async fn abort_all(&self) { + pub async fn abort_all(&self) -> Result<()> { // Drain under the lock, then release it before awaiting rollbacks so we // don't hold the mutex across a chain of awaits. let drained: Vec<(String, ActiveInterruptibleTransaction)> = { @@ -383,10 +403,28 @@ impl ActiveInterruptibleTransactions { "Rolling back interruptible transaction for database: {}", db_path ); - if let Err(err) = tx.rollback().await { - warn!("rollback during abort_all failed (db: {db_path}): {err}"); - } + tx.rollback().await?; } + + Ok(()) + } + + /// Roll back and remove the interruptible transaction for a single database, if any. + pub async fn abort_for_db(&self, db_key: &str) -> Result<()> { + let maybe_tx = { + let mut txs = self.inner.lock().await; + txs.remove(db_key) + }; + + if let Some(tx) = maybe_tx { + debug!( + "Rolling back interruptible transaction for database: {}", + db_key + ); + tx.rollback().await?; + } + + Ok(()) } /// Remove and return transaction for commit/rollback. @@ -435,14 +473,32 @@ impl ActiveInterruptibleTransactions { /// Tracking for regular (non-pausable) transactions that are in-flight. /// -/// Holds abort handles so transactions can be cancelled on app exit. +/// Holds join handles so transactions can be cancelled and awaited on close. #[derive(Clone, Default)] -pub struct ActiveRegularTransactions(Arc>>); +pub struct ActiveRegularTransactions(Arc>>>); + +async fn abort_and_await_regular_handles(handles: Vec<(String, JoinHandle<()>)>) -> Result<()> { + for (key, handle) in handles { + debug!("Aborting regular transaction: {}", key); + handle.abort(); + match handle.await { + Ok(()) => {} + Err(e) if e.is_cancelled() => {} + Err(e) => { + return Err(Error::Other(format!( + "regular transaction task panicked: {e}" + ))); + } + } + } + + Ok(()) +} impl ActiveRegularTransactions { - pub async fn insert(&self, key: String, abort_handle: AbortHandle) { + pub async fn insert(&self, key: String, handle: JoinHandle<()>) { let mut txs = self.0.write().await; - txs.insert(key, abort_handle); + txs.insert(key, handle); } pub async fn remove(&self, key: &str) { @@ -450,16 +506,25 @@ impl ActiveRegularTransactions { txs.remove(key); } - pub async fn abort_all(&self) { - let mut txs = self.0.write().await; - debug!("Aborting {} active regular transaction(s)", txs.len()); + pub async fn abort_all(&self) -> Result<()> { + let handles: Vec<(String, JoinHandle<()>)> = { + let mut txs = self.0.write().await; + debug!("Aborting {} active regular transaction(s)", txs.len()); + txs.drain().collect() + }; - for (key, abort_handle) in txs.iter() { - debug!("Aborting regular transaction: {}", key); - abort_handle.abort(); - } + abort_and_await_regular_handles(handles).await + } - txs.clear(); + /// Abort in-flight regular transactions for a single database. + pub async fn abort_for_db(&self, db_key: &str) -> Result<()> { + let prefix = format!("{db_key}:"); + let handles: Vec<(String, JoinHandle<()>)> = { + let mut txs = self.0.write().await; + txs.extract_if(|key, _| key.starts_with(&prefix)).collect() + }; + + abort_and_await_regular_handles(handles).await } } @@ -467,11 +532,97 @@ impl ActiveRegularTransactions { pub async fn cleanup_all_transactions( interruptible: &ActiveInterruptibleTransactions, regular: &ActiveRegularTransactions, -) { +) -> Result<()> { debug!("Cleaning up all active transactions"); - interruptible.abort_all().await; - regular.abort_all().await; + interruptible.abort_all().await?; + regular.abort_all().await?; - debug!("Transaction cleanup initiated"); + debug!("Transaction cleanup complete"); + Ok(()) +} + +pub async fn cleanup_transactions_for_db( + db_key: &str, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result<()> { + interruptible_txs.abort_for_db(db_key).await?; + regular_txs.abort_for_db(db_key).await?; + Ok(()) +} + +#[cfg(test)] +mod abort_error_tests { + use super::*; + use crate::DatabaseWrapper; + use serde_json::json; + + async fn begin_test_transaction( + db: &DatabaseWrapper, + db_path: &str, + ) -> ActiveInterruptibleTransaction { + let guard = db.acquire_writer().await.unwrap(); + let mut writer = TransactionWriter::from(guard); + writer.begin_immediate().await.unwrap(); + ActiveInterruptibleTransaction::new( + db_path.to_string(), + uuid::Uuid::new_v4().to_string(), + writer, + ) + } + + #[tokio::test] + async fn test_abort_for_db_propagates_rollback_failure() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = temp_dir.path().join("fail.db"); + let db = DatabaseWrapper::connect(&db_path, None).await.unwrap(); + db.execute( + "CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)".into(), + vec![], + ) + .await + .unwrap(); + + let state = ActiveInterruptibleTransactions::default(); + let mut tx = begin_test_transaction(&db, "fail.db").await; + tx.continue_with(vec![( + "INSERT INTO t (val) VALUES (?)", + vec![json!("uncommitted")], + )]) + .await + .unwrap(); + let tx = tx.force_rollback_failure_for_test(); + state.insert("fail.db".into(), tx).await.unwrap(); + + let err = state.abort_for_db("fail.db").await.unwrap_err(); + assert!(err.to_string().contains("forced rollback failure")); + } + + #[tokio::test] + async fn test_abort_all_propagates_rollback_failure() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = temp_dir.path().join("fail-all.db"); + let db = DatabaseWrapper::connect(&db_path, None).await.unwrap(); + db.execute( + "CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)".into(), + vec![], + ) + .await + .unwrap(); + + let state = ActiveInterruptibleTransactions::default(); + let mut tx = begin_test_transaction(&db, "fail-all.db").await; + tx.continue_with(vec![( + "INSERT INTO t (val) VALUES (?)", + vec![json!("uncommitted")], + )]) + .await + .unwrap(); + let tx = tx.force_rollback_failure_for_test(); + state.insert("fail-all.db".into(), tx).await.unwrap(); + + let err = state.abort_all().await.unwrap_err(); + assert!(err.to_string().contains("forced rollback failure")); + } } diff --git a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs index 7ac675d..52023b6 100644 --- a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs +++ b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs @@ -134,13 +134,41 @@ async fn test_abort_all_clears_transactions() { let tx_id = tx.transaction_id().to_string(); state.insert("abort.db".into(), tx).await.unwrap(); - state.abort_all().await; + state.abort_all().await.unwrap(); // After abort_all, remove should fail (transaction was cleared) let err = expect_err(state.remove("abort.db", &tx_id).await); assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); } +#[tokio::test] +async fn test_abort_for_db_clears_only_matching_interruptible() { + let (db1, _temp1) = create_test_db("main.db").await; + let (db2, _temp2) = create_test_db("other.db").await; + + for db in [&db1, &db2] { + db.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)".into(), vec![]) + .await + .unwrap(); + } + + let state = ActiveInterruptibleTransactions::default(); + let main_tx = begin_transaction(&db1, "main").await; + let main_tx_id = main_tx.transaction_id().to_string(); + let other_tx = begin_transaction(&db2, "other").await; + let other_tx_id = other_tx.transaction_id().to_string(); + + state.insert("main".into(), main_tx).await.unwrap(); + state.insert("other".into(), other_tx).await.unwrap(); + + state.abort_for_db("main").await.unwrap(); + + let err = expect_err(state.remove("main", &main_tx_id).await); + assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); + + assert!(state.remove("other", &other_tx_id).await.is_ok()); +} + #[tokio::test] async fn test_abort_all_auto_rollbacks_uncommitted_writes() { let (db, _temp) = create_test_db("rollback.db").await; @@ -165,7 +193,7 @@ async fn test_abort_all_auto_rollbacks_uncommitted_writes() { // Store and abort (should auto-rollback on drop) state.insert("rollback.db".into(), tx).await.unwrap(); - state.abort_all().await; + state.abort_all().await.unwrap(); // The uncommitted write should not be visible let rows = db @@ -195,7 +223,7 @@ async fn test_insert_after_abort_all_succeeds() { let tx = begin_transaction(&db1, "reuse-key").await; state.insert("reuse-key".into(), tx).await.unwrap(); - state.abort_all().await; + state.abort_all().await.unwrap(); // Should be able to insert again after abort let tx2 = begin_transaction(&db2, "reuse-key").await; @@ -285,7 +313,7 @@ async fn test_regular_insert_and_remove() { let state = ActiveRegularTransactions::default(); let handle = tokio::spawn(async { /* no-op */ }); - state.insert("tx-1".into(), handle.abort_handle()).await; + state.insert("tx-1".into(), handle).await; // Remove should succeed (no panic, no error) state.remove("tx-1").await; @@ -298,18 +326,12 @@ async fn test_regular_insert_and_remove() { async fn test_regular_abort_all_cancels_tasks() { let state = ActiveRegularTransactions::default(); - // Spawn a long-running task let handle = tokio::spawn(async { tokio::time::sleep(std::time::Duration::from_secs(60)).await; }); - let abort_handle = handle.abort_handle(); - state.insert("long-task".into(), abort_handle).await; - state.abort_all().await; - - // The task should have been aborted - let result = handle.await; - assert!(result.unwrap_err().is_cancelled()); + state.insert("long-task".into(), handle).await; + state.abort_all().await.unwrap(); } #[tokio::test] @@ -319,14 +341,34 @@ async fn test_regular_abort_all_clears_state() { let h1 = tokio::spawn(async {}); let h2 = tokio::spawn(async {}); - state.insert("a".into(), h1.abort_handle()).await; - state.insert("b".into(), h2.abort_handle()).await; + state.insert("a".into(), h1).await; + state.insert("b".into(), h2).await; - state.abort_all().await; + state.abort_all().await.unwrap(); // State should be empty — inserting new keys should work let h3 = tokio::spawn(async {}); - state.insert("a".into(), h3.abort_handle()).await; + state.insert("a".into(), h3).await; +} + +#[tokio::test] +async fn test_regular_abort_for_db_only_matching_prefix() { + let state = ActiveRegularTransactions::default(); + + let main_handle = tokio::spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + let other_handle = tokio::spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + state.insert("main:one".into(), main_handle).await; + state.insert("other:two".into(), other_handle).await; + + state.abort_for_db("main").await.unwrap(); + + // Other database transaction should still be tracked and abortable. + state.abort_for_db("other").await.unwrap(); } // ============================================================================ @@ -352,17 +394,14 @@ async fn test_cleanup_all_transactions() { let handle = tokio::spawn(async { tokio::time::sleep(std::time::Duration::from_secs(60)).await; }); - regular - .insert("regular-1".into(), handle.abort_handle()) - .await; + regular.insert("regular-1".into(), handle).await; // Cleanup should clear both - cleanup_all_transactions(&interruptible, ®ular).await; + cleanup_all_transactions(&interruptible, ®ular) + .await + .unwrap(); // Interruptible should be empty let err = expect_err(interruptible.remove("cleanup.db", "any").await); assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); - - // Regular task should be cancelled - assert!(handle.await.unwrap_err().is_cancelled()); } diff --git a/src/commands.rs b/src/commands.rs index 57f5e9a..54ee596 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -18,13 +18,13 @@ use tauri::{AppHandle, Runtime, State}; use tracing::debug; use uuid::Uuid; -use crate::connect_to_database; use crate::{ DbInstances, Error, MigrationEvent, MigrationStates, Result, subscriptions::{ ActiveSubscriptions, ObserverConfigParams, TableChangePayload, event_to_payload, }, }; +use crate::{close_all_loaded_databases, close_database, connect_to_database}; /// Token representing an active interruptible transaction #[derive(Debug, Clone, Serialize, Deserialize)] @@ -174,10 +174,11 @@ pub async fn execute_transaction( None }; - // Spawn transaction execution with abort handle for cleanup on exit + // Spawn transaction execution with join handle tracked for cleanup on close let wrapper_clone = wrapper.clone(); let tx_key_clone = tx_key.clone(); let regular_txs_clone = regular_txs.inner().clone(); + let (result_tx, result_rx) = tokio::sync::oneshot::channel(); let handle = tokio::spawn(async move { // Convert String to &str for execute_transaction @@ -196,27 +197,18 @@ pub async fn execute_transaction( // Remove from tracking when complete (even if result is Err) regular_txs_clone.remove(&tx_key_clone).await; - - result + let _ = result_tx.send(result); }); - // Track abort handle for cleanup on app exit - regular_txs - .insert(tx_key.clone(), handle.abort_handle()) - .await; + regular_txs.insert(tx_key.clone(), handle).await; - // Wait for transaction to complete - match handle.await { + match result_rx.await { Ok(result) => Ok(result?), - Err(e) => { - // Task panicked or was aborted - ensure cleanup + Err(_) => { regular_txs.remove(&tx_key).await; - - if e.is_cancelled() { - Err(Error::Other("Transaction aborted due to app exit".into())) - } else { - Err(Error::Other(format!("Transaction task panicked: {}", e))) - } + Err(Error::Other( + "Transaction task dropped before completion".into(), + )) } } } @@ -322,57 +314,53 @@ pub async fn fetch_page( Ok(result) } -/// Close a specific database connection +/// Close the loaded instance for a registered database key. /// /// Returns `true` if the database was loaded and successfully closed. /// Returns `false` if the database was not loaded (nothing to close). -/// Any active subscriptions for this database are aborted before closing. +/// Returns `Err` if transaction cleanup or pool close fails (database file +/// may not be safe to delete or recreate). +/// Active subscriptions for this key are aborted, and in-flight transactions +/// are cleaned up (interruptible transactions rolled back; regular transaction +/// tasks aborted and awaited) before the connection pool is closed. #[tauri::command] pub async fn close( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, + interruptible_txs: State<'_, ActiveInterruptibleTransactions>, + regular_txs: State<'_, ActiveRegularTransactions>, db_key: String, ) -> Result { - active_subs.remove_for_db(&db_key).await; - - let mut instances = db_instances.inner.write().await; - - if let Some(wrapper) = instances.remove(&db_key) { - wrapper.close().await?; - Ok(true) - } else { - Ok(false) // Database wasn't loaded - } + close_database( + &db_key, + &db_instances, + &active_subs, + &interruptible_txs, + ®ular_txs, + ) + .await } -/// Close all database connections +/// Close all database connections. /// -/// All active subscriptions are aborted before closing. Each wrapper's -/// `close()` handles disabling its own observer at the crate level. +/// All active subscriptions are aborted and in-flight transactions are cleaned +/// up (interruptible transactions rolled back; regular transaction tasks +/// aborted and awaited) before connection pools are closed. +/// Returns `Err` if transaction cleanup or any pool close fails. #[tauri::command] pub async fn close_all( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, + interruptible_txs: State<'_, ActiveInterruptibleTransactions>, + regular_txs: State<'_, ActiveRegularTransactions>, ) -> Result<()> { - active_subs.abort_all().await; - - let mut instances = db_instances.inner.write().await; - - // Collect all wrappers to close - let wrappers: Vec = instances.drain().map(|(_, v)| v).collect(); - - // Close each connection, continuing on errors to ensure all get closed - let mut last_error: Option = None; - for wrapper in wrappers { - if let Err(e) = wrapper.close().await { - last_error = Some(e.into()); - } - } - - match last_error { - Some(e) => Err(e), - None => Ok(()), - } + close_all_loaded_databases( + &db_instances, + &active_subs, + &interruptible_txs, + ®ular_txs, + ) + .await } /// Close database connection and remove all database files diff --git a/src/error.rs b/src/error.rs index 4b1fd77..6aaaa4f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -55,6 +55,10 @@ pub enum Error { #[error("invalid configuration: {0}")] InvalidConfig(String), + /// Required plugin managed state was not found. + #[error("required plugin state not found: {0}")] + MissingState(String), + /// Generic error for operations that don't fit other categories. #[error("{0}")] Other(String), @@ -94,6 +98,7 @@ impl Error { Error::TooManyDatabases(_) => "TOO_MANY_DATABASES".to_string(), Error::TooManySubscriptions(_) => "TOO_MANY_SUBSCRIPTIONS".to_string(), Error::InvalidConfig(_) => "INVALID_CONFIG".to_string(), + Error::MissingState(_) => "MISSING_STATE".to_string(), Error::Other(_) => "ERROR".to_string(), } } @@ -146,6 +151,12 @@ mod tests { ); } + #[test] + fn test_error_code_missing_state() { + let err = Error::MissingState("DbInstances".into()); + assert_eq!(err.error_code(), "MISSING_STATE"); + } + #[test] fn test_error_code_unsupported_datatype() { let err = Error::Toolkit(sqlx_sqlite_toolkit::Error::UnsupportedDatatype( diff --git a/src/lib.rs b/src/lib.rs index 37ded05..fe02152 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,8 @@ pub use sqlx_sqlite_toolkit::{ TransactionExecutionBuilder, WriteQueryResult, }; +use crate::subscriptions::ActiveSubscriptions; + /// Default maximum number of concurrently loaded databases. const DEFAULT_MAX_DATABASES: usize = 50; @@ -548,8 +550,13 @@ impl Builder { let migration_states = app.state::(); { let mut states = migration_states.0.blocking_write(); - for key in database_info_by_key.keys() { - states.insert(key.clone(), MigrationState::new()); + // Only track migration state for databases that have a migrator. + // Keys without migrations are omitted so `await_migrations` returns + // immediately instead of waiting on a Pending state that never runs. + for (key, info) in &database_info_by_key { + if info.migrator.is_some() { + states.insert(key.clone(), MigrationState::new()); + } } } @@ -631,29 +638,21 @@ impl Builder { let timeout_result = tokio::time::timeout( std::time::Duration::from_secs(5), async { - // First, abort all subscriptions and transactions debug!("Aborting active subscriptions and transactions"); active_subs_clone.abort_all().await; - sqlx_sqlite_toolkit::cleanup_all_transactions(&interruptible_txs_clone, ®ular_txs_clone).await; - - // Close databases (each wrapper's close() disables its own - // observer at the crate level, unregistering SQLite hooks) - let mut guard = instances_clone.inner.write().await; - let wrappers: Vec = - guard.drain().map(|(_, v)| v).collect(); - - // Close databases in parallel - let mut set = tokio::task::JoinSet::new(); - for wrapper in wrappers { - set.spawn(async move { wrapper.close().await }); + if let Err(e) = sqlx_sqlite_toolkit::cleanup_all_transactions( + &interruptible_txs_clone, + ®ular_txs_clone, + ) + .await + { + warn!("Transaction cleanup failed during exit: {e}"); } - while let Some(result) = set.join_next().await { - match result { - Ok(Err(e)) => warn!("Error closing database: {:?}", e), - Err(e) => warn!("Database close task panicked: {:?}", e), - Ok(Ok(())) => {} - } + if let Err(e) = + close_all_wrappers(&instances_clone).await + { + warn!("Error closing databases during exit: {e:?}"); } }, ) @@ -890,6 +889,17 @@ pub trait Connection { database_key: &str, config: SqliteDatabaseConfig, ) -> impl Future> + Send; + + /// Close the loaded instance for a registered database key. + /// + /// Returns `true` if the database was loaded and successfully closed. + /// Returns `false` if the database was not loaded (nothing to close). + /// Returns `Err` if transaction cleanup or pool close fails (database file + /// may not be safe to delete or recreate). + /// Active subscriptions for this key are aborted, and in-flight transactions + /// are cleaned up (interruptible transactions rolled back; regular transaction + /// tasks aborted and awaited) before the connection pool is closed. + fn close(&self, database_key: &str) -> impl Future> + Send; } /// Delegates to [`connect_to_database`]: same open path as the `load` IPC command. @@ -907,6 +917,32 @@ impl Connection for AppHandle { let response = connect_to_database(self, database_key, Some(config)).await?; Ok(response.wrapper) } + + async fn close(&self, database_key: &str) -> Result { + let instances = self + .try_state::() + .ok_or(Error::MissingState("DbInstances".into()))?; + let subs = self + .try_state::() + .ok_or(Error::MissingState("ActiveSubscriptions".into()))?; + let interruptible_txs = + self + .try_state::() + .ok_or(Error::MissingState( + "ActiveInterruptibleTransactions".into(), + ))?; + let regular_txs = self + .try_state::() + .ok_or(Error::MissingState("ActiveRegularTransactions".into()))?; + close_database( + database_key, + &instances, + &subs, + &interruptible_txs, + ®ular_txs, + ) + .await + } } struct ConnectionResponse { @@ -1005,6 +1041,58 @@ async fn await_migrations(migration_states: &MigrationStates, db_key: &str) -> R } } +pub(crate) async fn close_database( + db_key: &str, + db_instances: &DbInstances, + active_subs: &ActiveSubscriptions, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result { + active_subs.remove_for_db(db_key).await; + + sqlx_sqlite_toolkit::cleanup_transactions_for_db(db_key, interruptible_txs, regular_txs).await?; + + let mut instances = db_instances.inner.write().await; + + if let Some(wrapper) = instances.remove(db_key) { + wrapper.close().await?; + Ok(true) + } else { + Ok(false) // Database wasn't loaded + } +} + +async fn close_all_wrappers(db_instances: &DbInstances) -> Result<()> { + let mut instances = db_instances.inner.write().await; + let wrappers: Vec = instances.drain().map(|(_, v)| v).collect(); + drop(instances); + + let mut last_error: Option = None; + for wrapper in wrappers { + if let Err(e) = wrapper.close().await { + last_error = Some(e.into()); + } + } + + match last_error { + Some(e) => Err(e), + None => Ok(()), + } +} + +/// Close all loaded database instances after aborting subscriptions and +/// cleaning up in-flight transactions. +pub(crate) async fn close_all_loaded_databases( + db_instances: &DbInstances, + active_subs: &ActiveSubscriptions, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result<()> { + active_subs.abort_all().await; + sqlx_sqlite_toolkit::cleanup_all_transactions(interruptible_txs, regular_txs).await?; + close_all_wrappers(db_instances).await +} + /// Resolve a registered database path by key. /// /// The `db_key` must match a key registered via @@ -1024,10 +1112,112 @@ fn resolve_database_path(db_key: &str, app: &AppHandle) -> Result #[cfg(test)] mod tests { use super::*; + use crate::commands; use std::collections::HashMap; - use tauri::Manager; use tauri::plugin::Plugin; use tauri::test::{MockRuntime, mock_app, mock_builder, mock_context, noop_assets}; + use uuid::Uuid; + + /// Build and initialize the plugin for a single registered database. + /// + /// Must run **outside** a Tokio runtime context (or on a `spawn_blocking` thread). + /// Plugin `.setup()` calls `tokio::sync::RwLock::blocking_write()`, which panics + /// if invoked while the current thread is already driving async tasks (e.g. inside + /// `#[tokio::test]`). Integration tests below use `#[test]` + `block_on` and run + /// initialization via `spawn_blocking` before awaiting commands. + fn init_app_with_registered_db_at_path( + key: &str, + db_path: PathBuf, + ) -> (tauri::App, PathBuf) { + let mut plugin = Builder::::new() + .register_database(key, &db_path, None) + .unwrap() + .build() + .unwrap(); + let app = mock_app(); + plugin + .initialize(app.handle(), serde_json::Value::default()) + .expect("plugin init should succeed"); + (app, db_path) + } + + fn init_app_with_main_and_other( + main_path: PathBuf, + other_path: PathBuf, + ) -> tauri::App { + let mut plugin = Builder::::new() + .register_database("MAIN", &main_path, None) + .unwrap() + .register_database("OTHER", &other_path, None) + .unwrap() + .build() + .unwrap(); + let app = mock_app(); + plugin + .initialize(app.handle(), serde_json::Value::default()) + .expect("plugin init should succeed"); + app + } + + async fn load_and_create_test_table(app: &tauri::App, db_key: &str) { + connect_to_database(app.handle(), db_key, None) + .await + .expect("connect should succeed"); + + commands::execute( + app.state::(), + db_key.to_string(), + "CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT)".to_string(), + vec![], + None, + ) + .await + .expect("create table should succeed"); + } + + /// Holds a writer mid-transaction so `close` must abort with a checked-out connection. + async fn spawn_tracked_mid_write_regular_transaction( + app: &tauri::App, + ) -> tokio::sync::oneshot::Receiver<()> { + use sqlx_sqlite_toolkit::TransactionWriter; + + let wrapper = { + let instances = app.state::().inner().inner.read().await; + instances + .get("MAIN") + .expect("MAIN should be loaded") + .clone() + }; + let tx_key = format!("MAIN:{}", Uuid::new_v4()); + let regular_txs = app.state::().inner().clone(); + let (started_tx, started_rx) = tokio::sync::oneshot::channel(); + let tx_key_for_task = tx_key.clone(); + + let handle = tokio::spawn(async move { + let guard = wrapper.acquire_writer().await.expect("acquire writer"); + let mut writer = TransactionWriter::from(guard); + writer + .begin_immediate() + .await + .expect("begin immediate should succeed"); + writer + .execute_query( + sqlx::query("INSERT INTO test (val) VALUES (?)").bind("should-not-commit"), + ) + .await + .expect("insert should succeed"); + started_tx.send(()).ok(); + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + regular_txs.remove(&tx_key_for_task).await; + }); + + app.state::() + .inner() + .insert(tx_key, handle) + .await; + + started_rx + } fn builder_with_duplicate_paths(temp_dir: &tempfile::TempDir) -> Builder { let path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); @@ -1351,4 +1541,332 @@ mod tests { Some(std::time::Duration::from_secs(1)) ); } + + #[test] + fn test_connect_without_migrator_does_not_wait_on_migration_state() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let key = "MAIN".to_string(); + + tauri::async_runtime::block_on(async { + // `plugin.initialize()` must not run on the runtime worker thread — see + // `init_app_with_registered_db_at_path` for why we use `spawn_blocking`. + let (app, path) = + tokio::task::spawn_blocking(move || init_app_with_registered_db_at_path(&key, db_path)) + .await + .expect("plugin init task should succeed"); + + let migration_states = app.state::(); + assert!( + !migration_states.0.read().await.contains_key("MAIN"), + "databases without a migrator should not have migration state" + ); + + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .expect("connect should not block on migration state"); + assert_eq!(response.path, path); + }); + } + + #[test] + fn test_close_rolls_back_interruptible_transaction() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let key = "MAIN".to_string(); + + tauri::async_runtime::block_on(async { + let (app, _) = + tokio::task::spawn_blocking(move || init_app_with_registered_db_at_path(&key, db_path)) + .await + .expect("plugin init task should succeed"); + + load_and_create_test_table(&app, "MAIN").await; + + commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "MAIN".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("uncommitted")], + }], + None, + ) + .await + .expect("begin interruptible transaction should succeed"); + + let closed = commands::close( + app.state::(), + app.state::(), + app.state::(), + app.state::(), + "MAIN".to_string(), + ) + .await + .expect("close should succeed"); + assert!(closed); + + assert!( + app.state::() + .inner() + .inner + .read() + .await + .get("MAIN") + .is_none(), + "close should remove the loaded instance" + ); + + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .expect("reload after close should succeed"); + let rows = response + .wrapper + .fetch_all("SELECT val FROM test".into(), vec![]) + .await + .expect("fetch after close should succeed"); + + assert!( + rows.is_empty(), + "close should roll back interruptible transaction before closing" + ); + }); + } + + #[test] + fn test_close_aborts_in_flight_regular_transaction() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let key = "MAIN".to_string(); + + tauri::async_runtime::block_on(async { + let (app, path) = + tokio::task::spawn_blocking(move || init_app_with_registered_db_at_path(&key, db_path)) + .await + .expect("plugin init task should succeed"); + + load_and_create_test_table(&app, "MAIN").await; + + let started_rx = spawn_tracked_mid_write_regular_transaction(&app).await; + started_rx + .await + .expect("regular transaction should hold writer mid-flight"); + + let closed = commands::close( + app.state::(), + app.state::(), + app.state::(), + app.state::(), + "MAIN".to_string(), + ) + .await + .expect("close should succeed"); + assert!(closed); + + assert!( + app.state::() + .inner() + .inner + .read() + .await + .get("MAIN") + .is_none(), + "close should remove the loaded instance" + ); + + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .expect("reload after close should succeed"); + let rows = response + .wrapper + .fetch_all("SELECT val FROM test".into(), vec![]) + .await + .expect("fetch after close should succeed"); + + assert!( + rows.is_empty(), + "close should abort mid-write regular transaction before closing" + ); + + std::fs::remove_file(&path).expect("database file should be safe to delete after close"); + }); + } + + #[test] + fn test_close_all_cleans_up_transactions() { + let temp_dir = tempfile::tempdir().unwrap(); + let main_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let other_path = validate::validate_database_path(temp_dir.path().join("other.db")).unwrap(); + + let main_path_for_delete = main_path.clone(); + + tauri::async_runtime::block_on(async { + let app = tokio::task::spawn_blocking(move || { + init_app_with_main_and_other(main_path, other_path) + }) + .await + .expect("plugin init task should succeed"); + + load_and_create_test_table(&app, "MAIN").await; + load_and_create_test_table(&app, "OTHER").await; + + commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "MAIN".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("uncommitted")], + }], + None, + ) + .await + .expect("begin interruptible transaction should succeed"); + + commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "OTHER".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("other-uncommitted")], + }], + None, + ) + .await + .expect("begin OTHER interruptible transaction should succeed"); + + commands::close_all( + app.state::(), + app.state::(), + app.state::(), + app.state::(), + ) + .await + .expect("close_all should succeed"); + + assert!( + app.state::() + .inner() + .inner + .read() + .await + .is_empty(), + "close_all should remove all loaded instances" + ); + + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .expect("reload after close_all should succeed"); + let rows = response + .wrapper + .fetch_all("SELECT val FROM test".into(), vec![]) + .await + .expect("fetch after close_all should succeed"); + + assert!( + rows.is_empty(), + "close_all should roll back interruptible transactions before closing" + ); + + std::fs::remove_file(&main_path_for_delete) + .expect("database file should be safe to delete after close_all"); + }); + } + + #[test] + fn test_close_only_aborts_transactions_for_target_database() { + let temp_dir = tempfile::tempdir().unwrap(); + let main_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let other_path = validate::validate_database_path(temp_dir.path().join("other.db")).unwrap(); + + tauri::async_runtime::block_on(async { + let app = tokio::task::spawn_blocking(move || { + init_app_with_main_and_other(main_path, other_path) + }) + .await + .expect("plugin init task should succeed"); + + load_and_create_test_table(&app, "MAIN").await; + load_and_create_test_table(&app, "OTHER").await; + + let main_token = commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "MAIN".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("main-uncommitted")], + }], + None, + ) + .await + .expect("begin MAIN interruptible transaction should succeed"); + + let other_token = commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "OTHER".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("other-uncommitted")], + }], + None, + ) + .await + .expect("begin OTHER interruptible transaction should succeed"); + + let closed = commands::close( + app.state::(), + app.state::(), + app.state::(), + app.state::(), + "MAIN".to_string(), + ) + .await + .expect("close should succeed"); + assert!(closed); + + assert!( + app.state::() + .inner() + .inner + .read() + .await + .get("MAIN") + .is_none() + ); + assert!( + app.state::() + .inner() + .inner + .read() + .await + .get("OTHER") + .is_some() + ); + + let err = commands::transaction_continue( + app.state::(), + main_token, + commands::TransactionAction::Commit, + ) + .await + .expect_err("MAIN transaction should have been aborted on close"); + assert!(matches!( + err, + Error::Toolkit(sqlx_sqlite_toolkit::Error::NoActiveTransaction(_)) + )); + + let continued = commands::transaction_continue( + app.state::(), + other_token, + commands::TransactionAction::Rollback, + ) + .await + .expect("OTHER transaction should still be active"); + assert!(continued.is_none()); + }); + } }