diff --git a/CHANGELOG.md b/CHANGELOG.md index 330fe9c..1cae6c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,20 @@ Databases must be registered on the Rust side with a stable key before they can - Parent directory auto-creation during registration validation for file paths. - CI check that committed `api-iife.js` matches a fresh Rollup build. +### Changed + +#### `close` aborts active transactions before closing + +`Database.close()` (IPC `plugin:sqlite|close`), `Database.close_all()` (IPC `plugin:sqlite|close_all`), and the Rust [`Connection::close`](src/lib.rs) API now roll back or cancel in-flight transactions before closing connection pools. + +- **Interruptible transactions** (`beginInterruptibleTransaction`): explicitly rolled back via `ROLLBACK` before the pool is closed. +- **Regular transactions** (`executeTransaction`): the in-flight task is aborted and awaited so pooled connections are released before the pool closes. + +Previously, `close` only aborted active subscriptions; open transactions could block a clean shutdown or leave uncommitted work on pooled connections. The same cleanup logic is available to Rust callers through [`close_database`](src/lib.rs) and [`close_all_loaded_databases`](src/lib.rs). + +Transaction cleanup failures propagate as errors rather than being logged and ignored, so a successful close indicates the database file is safe to delete or recreate. + ### Fixed +- `load()` no longer hangs forever for databases registered without a migrator. Migration state is only tracked when a migrator is provided; otherwise `await_migrations` proceeds immediately. - Regenerated `api-iife.js` so all IPC calls use `dbKey` consistently. diff --git a/crates/sqlx-sqlite-toolkit/src/lib.rs b/crates/sqlx-sqlite-toolkit/src/lib.rs index 201ba2d..627801c 100644 --- a/crates/sqlx-sqlite-toolkit/src/lib.rs +++ b/crates/sqlx-sqlite-toolkit/src/lib.rs @@ -46,7 +46,7 @@ pub use error::{Error, Result}; pub use pagination::{KeysetColumn, KeysetPage, SortDirection}; pub use transactions::{ ActiveInterruptibleTransaction, ActiveInterruptibleTransactions, ActiveRegularTransactions, - Statement, TransactionWriter, cleanup_all_transactions, + Statement, TransactionWriter, cleanup_all_transactions, cleanup_transactions_for_db, }; pub use wrapper::{ DatabaseWrapper, InterruptibleTransaction, InterruptibleTransactionBuilder, diff --git a/crates/sqlx-sqlite-toolkit/src/transactions.rs b/crates/sqlx-sqlite-toolkit/src/transactions.rs index 559d9eb..db78056 100644 --- a/crates/sqlx-sqlite-toolkit/src/transactions.rs +++ b/crates/sqlx-sqlite-toolkit/src/transactions.rs @@ -10,7 +10,7 @@ use serde_json::Value as JsonValue; use sqlx::{Column, Row}; use sqlx_sqlite_conn_mgr::{AttachedWriteGuard, WriteGuard}; use tokio::sync::{Mutex, RwLock}; -use tokio::task::AbortHandle; +use tokio::task::JoinHandle; use tracing::{debug, warn}; #[cfg(feature = "observer")] @@ -106,6 +106,8 @@ pub struct ActiveInterruptibleTransaction { // call sqlx's rt::spawn and panic with "this functionality requires a Tokio // context". runtime_handle: tokio::runtime::Handle, + #[cfg(test)] + force_rollback_failure: bool, } impl ActiveInterruptibleTransaction { @@ -121,9 +123,18 @@ impl ActiveInterruptibleTransaction { writer: Some(writer), created_at: Instant::now(), runtime_handle: tokio::runtime::Handle::current(), + #[cfg(test)] + force_rollback_failure: false, } } + /// Test-only: force `rollback()` to fail when aborting via transaction state. + #[cfg(test)] + pub fn force_rollback_failure_for_test(mut self) -> Self { + self.force_rollback_failure = true; + self + } + fn writer_mut(&mut self) -> Result<&mut TransactionWriter> { self .writer @@ -208,6 +219,15 @@ impl ActiveInterruptibleTransaction { /// Rollback this transaction pub async fn rollback(mut self) -> Result<()> { + #[cfg(test)] + if self.force_rollback_failure { + let db_path = self.db_path.clone(); + drop(self.take_writer()?); + return Err(Error::Other(format!( + "forced rollback failure for test (db: {db_path})" + ))); + } + let mut writer = self.take_writer()?; writer.rollback().await?; @@ -369,7 +389,7 @@ impl ActiveInterruptibleTransactions { } } - pub async fn abort_all(&self) { + pub async fn abort_all(&self) -> Result<()> { // Drain under the lock, then release it before awaiting rollbacks so we // don't hold the mutex across a chain of awaits. let drained: Vec<(String, ActiveInterruptibleTransaction)> = { @@ -383,10 +403,28 @@ impl ActiveInterruptibleTransactions { "Rolling back interruptible transaction for database: {}", db_path ); - if let Err(err) = tx.rollback().await { - warn!("rollback during abort_all failed (db: {db_path}): {err}"); - } + tx.rollback().await?; } + + Ok(()) + } + + /// Roll back and remove the interruptible transaction for a single database, if any. + pub async fn abort_for_db(&self, db_key: &str) -> Result<()> { + let maybe_tx = { + let mut txs = self.inner.lock().await; + txs.remove(db_key) + }; + + if let Some(tx) = maybe_tx { + debug!( + "Rolling back interruptible transaction for database: {}", + db_key + ); + tx.rollback().await?; + } + + Ok(()) } /// Remove and return transaction for commit/rollback. @@ -435,14 +473,32 @@ impl ActiveInterruptibleTransactions { /// Tracking for regular (non-pausable) transactions that are in-flight. /// -/// Holds abort handles so transactions can be cancelled on app exit. +/// Holds join handles so transactions can be cancelled and awaited on close. #[derive(Clone, Default)] -pub struct ActiveRegularTransactions(Arc>>); +pub struct ActiveRegularTransactions(Arc>>>); + +async fn abort_and_await_regular_handles(handles: Vec<(String, JoinHandle<()>)>) -> Result<()> { + for (key, handle) in handles { + debug!("Aborting regular transaction: {}", key); + handle.abort(); + match handle.await { + Ok(()) => {} + Err(e) if e.is_cancelled() => {} + Err(e) => { + return Err(Error::Other(format!( + "regular transaction task panicked: {e}" + ))); + } + } + } + + Ok(()) +} impl ActiveRegularTransactions { - pub async fn insert(&self, key: String, abort_handle: AbortHandle) { + pub async fn insert(&self, key: String, handle: JoinHandle<()>) { let mut txs = self.0.write().await; - txs.insert(key, abort_handle); + txs.insert(key, handle); } pub async fn remove(&self, key: &str) { @@ -450,16 +506,25 @@ impl ActiveRegularTransactions { txs.remove(key); } - pub async fn abort_all(&self) { - let mut txs = self.0.write().await; - debug!("Aborting {} active regular transaction(s)", txs.len()); + pub async fn abort_all(&self) -> Result<()> { + let handles: Vec<(String, JoinHandle<()>)> = { + let mut txs = self.0.write().await; + debug!("Aborting {} active regular transaction(s)", txs.len()); + txs.drain().collect() + }; - for (key, abort_handle) in txs.iter() { - debug!("Aborting regular transaction: {}", key); - abort_handle.abort(); - } + abort_and_await_regular_handles(handles).await + } - txs.clear(); + /// Abort in-flight regular transactions for a single database. + pub async fn abort_for_db(&self, db_key: &str) -> Result<()> { + let prefix = format!("{db_key}:"); + let handles: Vec<(String, JoinHandle<()>)> = { + let mut txs = self.0.write().await; + txs.extract_if(|key, _| key.starts_with(&prefix)).collect() + }; + + abort_and_await_regular_handles(handles).await } } @@ -467,11 +532,97 @@ impl ActiveRegularTransactions { pub async fn cleanup_all_transactions( interruptible: &ActiveInterruptibleTransactions, regular: &ActiveRegularTransactions, -) { +) -> Result<()> { debug!("Cleaning up all active transactions"); - interruptible.abort_all().await; - regular.abort_all().await; + interruptible.abort_all().await?; + regular.abort_all().await?; - debug!("Transaction cleanup initiated"); + debug!("Transaction cleanup complete"); + Ok(()) +} + +pub async fn cleanup_transactions_for_db( + db_key: &str, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result<()> { + interruptible_txs.abort_for_db(db_key).await?; + regular_txs.abort_for_db(db_key).await?; + Ok(()) +} + +#[cfg(test)] +mod abort_error_tests { + use super::*; + use crate::DatabaseWrapper; + use serde_json::json; + + async fn begin_test_transaction( + db: &DatabaseWrapper, + db_path: &str, + ) -> ActiveInterruptibleTransaction { + let guard = db.acquire_writer().await.unwrap(); + let mut writer = TransactionWriter::from(guard); + writer.begin_immediate().await.unwrap(); + ActiveInterruptibleTransaction::new( + db_path.to_string(), + uuid::Uuid::new_v4().to_string(), + writer, + ) + } + + #[tokio::test] + async fn test_abort_for_db_propagates_rollback_failure() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = temp_dir.path().join("fail.db"); + let db = DatabaseWrapper::connect(&db_path, None).await.unwrap(); + db.execute( + "CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)".into(), + vec![], + ) + .await + .unwrap(); + + let state = ActiveInterruptibleTransactions::default(); + let mut tx = begin_test_transaction(&db, "fail.db").await; + tx.continue_with(vec![( + "INSERT INTO t (val) VALUES (?)", + vec![json!("uncommitted")], + )]) + .await + .unwrap(); + let tx = tx.force_rollback_failure_for_test(); + state.insert("fail.db".into(), tx).await.unwrap(); + + let err = state.abort_for_db("fail.db").await.unwrap_err(); + assert!(err.to_string().contains("forced rollback failure")); + } + + #[tokio::test] + async fn test_abort_all_propagates_rollback_failure() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = temp_dir.path().join("fail-all.db"); + let db = DatabaseWrapper::connect(&db_path, None).await.unwrap(); + db.execute( + "CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)".into(), + vec![], + ) + .await + .unwrap(); + + let state = ActiveInterruptibleTransactions::default(); + let mut tx = begin_test_transaction(&db, "fail-all.db").await; + tx.continue_with(vec![( + "INSERT INTO t (val) VALUES (?)", + vec![json!("uncommitted")], + )]) + .await + .unwrap(); + let tx = tx.force_rollback_failure_for_test(); + state.insert("fail-all.db".into(), tx).await.unwrap(); + + let err = state.abort_all().await.unwrap_err(); + assert!(err.to_string().contains("forced rollback failure")); + } } diff --git a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs index 7ac675d..52023b6 100644 --- a/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs +++ b/crates/sqlx-sqlite-toolkit/tests/transaction_state_tests.rs @@ -134,13 +134,41 @@ async fn test_abort_all_clears_transactions() { let tx_id = tx.transaction_id().to_string(); state.insert("abort.db".into(), tx).await.unwrap(); - state.abort_all().await; + state.abort_all().await.unwrap(); // After abort_all, remove should fail (transaction was cleared) let err = expect_err(state.remove("abort.db", &tx_id).await); assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); } +#[tokio::test] +async fn test_abort_for_db_clears_only_matching_interruptible() { + let (db1, _temp1) = create_test_db("main.db").await; + let (db2, _temp2) = create_test_db("other.db").await; + + for db in [&db1, &db2] { + db.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)".into(), vec![]) + .await + .unwrap(); + } + + let state = ActiveInterruptibleTransactions::default(); + let main_tx = begin_transaction(&db1, "main").await; + let main_tx_id = main_tx.transaction_id().to_string(); + let other_tx = begin_transaction(&db2, "other").await; + let other_tx_id = other_tx.transaction_id().to_string(); + + state.insert("main".into(), main_tx).await.unwrap(); + state.insert("other".into(), other_tx).await.unwrap(); + + state.abort_for_db("main").await.unwrap(); + + let err = expect_err(state.remove("main", &main_tx_id).await); + assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); + + assert!(state.remove("other", &other_tx_id).await.is_ok()); +} + #[tokio::test] async fn test_abort_all_auto_rollbacks_uncommitted_writes() { let (db, _temp) = create_test_db("rollback.db").await; @@ -165,7 +193,7 @@ async fn test_abort_all_auto_rollbacks_uncommitted_writes() { // Store and abort (should auto-rollback on drop) state.insert("rollback.db".into(), tx).await.unwrap(); - state.abort_all().await; + state.abort_all().await.unwrap(); // The uncommitted write should not be visible let rows = db @@ -195,7 +223,7 @@ async fn test_insert_after_abort_all_succeeds() { let tx = begin_transaction(&db1, "reuse-key").await; state.insert("reuse-key".into(), tx).await.unwrap(); - state.abort_all().await; + state.abort_all().await.unwrap(); // Should be able to insert again after abort let tx2 = begin_transaction(&db2, "reuse-key").await; @@ -285,7 +313,7 @@ async fn test_regular_insert_and_remove() { let state = ActiveRegularTransactions::default(); let handle = tokio::spawn(async { /* no-op */ }); - state.insert("tx-1".into(), handle.abort_handle()).await; + state.insert("tx-1".into(), handle).await; // Remove should succeed (no panic, no error) state.remove("tx-1").await; @@ -298,18 +326,12 @@ async fn test_regular_insert_and_remove() { async fn test_regular_abort_all_cancels_tasks() { let state = ActiveRegularTransactions::default(); - // Spawn a long-running task let handle = tokio::spawn(async { tokio::time::sleep(std::time::Duration::from_secs(60)).await; }); - let abort_handle = handle.abort_handle(); - state.insert("long-task".into(), abort_handle).await; - state.abort_all().await; - - // The task should have been aborted - let result = handle.await; - assert!(result.unwrap_err().is_cancelled()); + state.insert("long-task".into(), handle).await; + state.abort_all().await.unwrap(); } #[tokio::test] @@ -319,14 +341,34 @@ async fn test_regular_abort_all_clears_state() { let h1 = tokio::spawn(async {}); let h2 = tokio::spawn(async {}); - state.insert("a".into(), h1.abort_handle()).await; - state.insert("b".into(), h2.abort_handle()).await; + state.insert("a".into(), h1).await; + state.insert("b".into(), h2).await; - state.abort_all().await; + state.abort_all().await.unwrap(); // State should be empty — inserting new keys should work let h3 = tokio::spawn(async {}); - state.insert("a".into(), h3.abort_handle()).await; + state.insert("a".into(), h3).await; +} + +#[tokio::test] +async fn test_regular_abort_for_db_only_matching_prefix() { + let state = ActiveRegularTransactions::default(); + + let main_handle = tokio::spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + let other_handle = tokio::spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + state.insert("main:one".into(), main_handle).await; + state.insert("other:two".into(), other_handle).await; + + state.abort_for_db("main").await.unwrap(); + + // Other database transaction should still be tracked and abortable. + state.abort_for_db("other").await.unwrap(); } // ============================================================================ @@ -352,17 +394,14 @@ async fn test_cleanup_all_transactions() { let handle = tokio::spawn(async { tokio::time::sleep(std::time::Duration::from_secs(60)).await; }); - regular - .insert("regular-1".into(), handle.abort_handle()) - .await; + regular.insert("regular-1".into(), handle).await; // Cleanup should clear both - cleanup_all_transactions(&interruptible, ®ular).await; + cleanup_all_transactions(&interruptible, ®ular) + .await + .unwrap(); // Interruptible should be empty let err = expect_err(interruptible.remove("cleanup.db", "any").await); assert_eq!(err.error_code(), "NO_ACTIVE_TRANSACTION"); - - // Regular task should be cancelled - assert!(handle.await.unwrap_err().is_cancelled()); } diff --git a/src/commands.rs b/src/commands.rs index 57f5e9a..54ee596 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -18,13 +18,13 @@ use tauri::{AppHandle, Runtime, State}; use tracing::debug; use uuid::Uuid; -use crate::connect_to_database; use crate::{ DbInstances, Error, MigrationEvent, MigrationStates, Result, subscriptions::{ ActiveSubscriptions, ObserverConfigParams, TableChangePayload, event_to_payload, }, }; +use crate::{close_all_loaded_databases, close_database, connect_to_database}; /// Token representing an active interruptible transaction #[derive(Debug, Clone, Serialize, Deserialize)] @@ -174,10 +174,11 @@ pub async fn execute_transaction( None }; - // Spawn transaction execution with abort handle for cleanup on exit + // Spawn transaction execution with join handle tracked for cleanup on close let wrapper_clone = wrapper.clone(); let tx_key_clone = tx_key.clone(); let regular_txs_clone = regular_txs.inner().clone(); + let (result_tx, result_rx) = tokio::sync::oneshot::channel(); let handle = tokio::spawn(async move { // Convert String to &str for execute_transaction @@ -196,27 +197,18 @@ pub async fn execute_transaction( // Remove from tracking when complete (even if result is Err) regular_txs_clone.remove(&tx_key_clone).await; - - result + let _ = result_tx.send(result); }); - // Track abort handle for cleanup on app exit - regular_txs - .insert(tx_key.clone(), handle.abort_handle()) - .await; + regular_txs.insert(tx_key.clone(), handle).await; - // Wait for transaction to complete - match handle.await { + match result_rx.await { Ok(result) => Ok(result?), - Err(e) => { - // Task panicked or was aborted - ensure cleanup + Err(_) => { regular_txs.remove(&tx_key).await; - - if e.is_cancelled() { - Err(Error::Other("Transaction aborted due to app exit".into())) - } else { - Err(Error::Other(format!("Transaction task panicked: {}", e))) - } + Err(Error::Other( + "Transaction task dropped before completion".into(), + )) } } } @@ -322,57 +314,53 @@ pub async fn fetch_page( Ok(result) } -/// Close a specific database connection +/// Close the loaded instance for a registered database key. /// /// Returns `true` if the database was loaded and successfully closed. /// Returns `false` if the database was not loaded (nothing to close). -/// Any active subscriptions for this database are aborted before closing. +/// Returns `Err` if transaction cleanup or pool close fails (database file +/// may not be safe to delete or recreate). +/// Active subscriptions for this key are aborted, and in-flight transactions +/// are cleaned up (interruptible transactions rolled back; regular transaction +/// tasks aborted and awaited) before the connection pool is closed. #[tauri::command] pub async fn close( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, + interruptible_txs: State<'_, ActiveInterruptibleTransactions>, + regular_txs: State<'_, ActiveRegularTransactions>, db_key: String, ) -> Result { - active_subs.remove_for_db(&db_key).await; - - let mut instances = db_instances.inner.write().await; - - if let Some(wrapper) = instances.remove(&db_key) { - wrapper.close().await?; - Ok(true) - } else { - Ok(false) // Database wasn't loaded - } + close_database( + &db_key, + &db_instances, + &active_subs, + &interruptible_txs, + ®ular_txs, + ) + .await } -/// Close all database connections +/// Close all database connections. /// -/// All active subscriptions are aborted before closing. Each wrapper's -/// `close()` handles disabling its own observer at the crate level. +/// All active subscriptions are aborted and in-flight transactions are cleaned +/// up (interruptible transactions rolled back; regular transaction tasks +/// aborted and awaited) before connection pools are closed. +/// Returns `Err` if transaction cleanup or any pool close fails. #[tauri::command] pub async fn close_all( db_instances: State<'_, DbInstances>, active_subs: State<'_, ActiveSubscriptions>, + interruptible_txs: State<'_, ActiveInterruptibleTransactions>, + regular_txs: State<'_, ActiveRegularTransactions>, ) -> Result<()> { - active_subs.abort_all().await; - - let mut instances = db_instances.inner.write().await; - - // Collect all wrappers to close - let wrappers: Vec = instances.drain().map(|(_, v)| v).collect(); - - // Close each connection, continuing on errors to ensure all get closed - let mut last_error: Option = None; - for wrapper in wrappers { - if let Err(e) = wrapper.close().await { - last_error = Some(e.into()); - } - } - - match last_error { - Some(e) => Err(e), - None => Ok(()), - } + close_all_loaded_databases( + &db_instances, + &active_subs, + &interruptible_txs, + ®ular_txs, + ) + .await } /// Close database connection and remove all database files diff --git a/src/error.rs b/src/error.rs index 4b1fd77..6aaaa4f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -55,6 +55,10 @@ pub enum Error { #[error("invalid configuration: {0}")] InvalidConfig(String), + /// Required plugin managed state was not found. + #[error("required plugin state not found: {0}")] + MissingState(String), + /// Generic error for operations that don't fit other categories. #[error("{0}")] Other(String), @@ -94,6 +98,7 @@ impl Error { Error::TooManyDatabases(_) => "TOO_MANY_DATABASES".to_string(), Error::TooManySubscriptions(_) => "TOO_MANY_SUBSCRIPTIONS".to_string(), Error::InvalidConfig(_) => "INVALID_CONFIG".to_string(), + Error::MissingState(_) => "MISSING_STATE".to_string(), Error::Other(_) => "ERROR".to_string(), } } @@ -146,6 +151,12 @@ mod tests { ); } + #[test] + fn test_error_code_missing_state() { + let err = Error::MissingState("DbInstances".into()); + assert_eq!(err.error_code(), "MISSING_STATE"); + } + #[test] fn test_error_code_unsupported_datatype() { let err = Error::Toolkit(sqlx_sqlite_toolkit::Error::UnsupportedDatatype( diff --git a/src/lib.rs b/src/lib.rs index 37ded05..fe02152 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,8 @@ pub use sqlx_sqlite_toolkit::{ TransactionExecutionBuilder, WriteQueryResult, }; +use crate::subscriptions::ActiveSubscriptions; + /// Default maximum number of concurrently loaded databases. const DEFAULT_MAX_DATABASES: usize = 50; @@ -548,8 +550,13 @@ impl Builder { let migration_states = app.state::(); { let mut states = migration_states.0.blocking_write(); - for key in database_info_by_key.keys() { - states.insert(key.clone(), MigrationState::new()); + // Only track migration state for databases that have a migrator. + // Keys without migrations are omitted so `await_migrations` returns + // immediately instead of waiting on a Pending state that never runs. + for (key, info) in &database_info_by_key { + if info.migrator.is_some() { + states.insert(key.clone(), MigrationState::new()); + } } } @@ -631,29 +638,21 @@ impl Builder { let timeout_result = tokio::time::timeout( std::time::Duration::from_secs(5), async { - // First, abort all subscriptions and transactions debug!("Aborting active subscriptions and transactions"); active_subs_clone.abort_all().await; - sqlx_sqlite_toolkit::cleanup_all_transactions(&interruptible_txs_clone, ®ular_txs_clone).await; - - // Close databases (each wrapper's close() disables its own - // observer at the crate level, unregistering SQLite hooks) - let mut guard = instances_clone.inner.write().await; - let wrappers: Vec = - guard.drain().map(|(_, v)| v).collect(); - - // Close databases in parallel - let mut set = tokio::task::JoinSet::new(); - for wrapper in wrappers { - set.spawn(async move { wrapper.close().await }); + if let Err(e) = sqlx_sqlite_toolkit::cleanup_all_transactions( + &interruptible_txs_clone, + ®ular_txs_clone, + ) + .await + { + warn!("Transaction cleanup failed during exit: {e}"); } - while let Some(result) = set.join_next().await { - match result { - Ok(Err(e)) => warn!("Error closing database: {:?}", e), - Err(e) => warn!("Database close task panicked: {:?}", e), - Ok(Ok(())) => {} - } + if let Err(e) = + close_all_wrappers(&instances_clone).await + { + warn!("Error closing databases during exit: {e:?}"); } }, ) @@ -890,6 +889,17 @@ pub trait Connection { database_key: &str, config: SqliteDatabaseConfig, ) -> impl Future> + Send; + + /// Close the loaded instance for a registered database key. + /// + /// Returns `true` if the database was loaded and successfully closed. + /// Returns `false` if the database was not loaded (nothing to close). + /// Returns `Err` if transaction cleanup or pool close fails (database file + /// may not be safe to delete or recreate). + /// Active subscriptions for this key are aborted, and in-flight transactions + /// are cleaned up (interruptible transactions rolled back; regular transaction + /// tasks aborted and awaited) before the connection pool is closed. + fn close(&self, database_key: &str) -> impl Future> + Send; } /// Delegates to [`connect_to_database`]: same open path as the `load` IPC command. @@ -907,6 +917,32 @@ impl Connection for AppHandle { let response = connect_to_database(self, database_key, Some(config)).await?; Ok(response.wrapper) } + + async fn close(&self, database_key: &str) -> Result { + let instances = self + .try_state::() + .ok_or(Error::MissingState("DbInstances".into()))?; + let subs = self + .try_state::() + .ok_or(Error::MissingState("ActiveSubscriptions".into()))?; + let interruptible_txs = + self + .try_state::() + .ok_or(Error::MissingState( + "ActiveInterruptibleTransactions".into(), + ))?; + let regular_txs = self + .try_state::() + .ok_or(Error::MissingState("ActiveRegularTransactions".into()))?; + close_database( + database_key, + &instances, + &subs, + &interruptible_txs, + ®ular_txs, + ) + .await + } } struct ConnectionResponse { @@ -1005,6 +1041,58 @@ async fn await_migrations(migration_states: &MigrationStates, db_key: &str) -> R } } +pub(crate) async fn close_database( + db_key: &str, + db_instances: &DbInstances, + active_subs: &ActiveSubscriptions, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result { + active_subs.remove_for_db(db_key).await; + + sqlx_sqlite_toolkit::cleanup_transactions_for_db(db_key, interruptible_txs, regular_txs).await?; + + let mut instances = db_instances.inner.write().await; + + if let Some(wrapper) = instances.remove(db_key) { + wrapper.close().await?; + Ok(true) + } else { + Ok(false) // Database wasn't loaded + } +} + +async fn close_all_wrappers(db_instances: &DbInstances) -> Result<()> { + let mut instances = db_instances.inner.write().await; + let wrappers: Vec = instances.drain().map(|(_, v)| v).collect(); + drop(instances); + + let mut last_error: Option = None; + for wrapper in wrappers { + if let Err(e) = wrapper.close().await { + last_error = Some(e.into()); + } + } + + match last_error { + Some(e) => Err(e), + None => Ok(()), + } +} + +/// Close all loaded database instances after aborting subscriptions and +/// cleaning up in-flight transactions. +pub(crate) async fn close_all_loaded_databases( + db_instances: &DbInstances, + active_subs: &ActiveSubscriptions, + interruptible_txs: &ActiveInterruptibleTransactions, + regular_txs: &ActiveRegularTransactions, +) -> Result<()> { + active_subs.abort_all().await; + sqlx_sqlite_toolkit::cleanup_all_transactions(interruptible_txs, regular_txs).await?; + close_all_wrappers(db_instances).await +} + /// Resolve a registered database path by key. /// /// The `db_key` must match a key registered via @@ -1024,10 +1112,112 @@ fn resolve_database_path(db_key: &str, app: &AppHandle) -> Result #[cfg(test)] mod tests { use super::*; + use crate::commands; use std::collections::HashMap; - use tauri::Manager; use tauri::plugin::Plugin; use tauri::test::{MockRuntime, mock_app, mock_builder, mock_context, noop_assets}; + use uuid::Uuid; + + /// Build and initialize the plugin for a single registered database. + /// + /// Must run **outside** a Tokio runtime context (or on a `spawn_blocking` thread). + /// Plugin `.setup()` calls `tokio::sync::RwLock::blocking_write()`, which panics + /// if invoked while the current thread is already driving async tasks (e.g. inside + /// `#[tokio::test]`). Integration tests below use `#[test]` + `block_on` and run + /// initialization via `spawn_blocking` before awaiting commands. + fn init_app_with_registered_db_at_path( + key: &str, + db_path: PathBuf, + ) -> (tauri::App, PathBuf) { + let mut plugin = Builder::::new() + .register_database(key, &db_path, None) + .unwrap() + .build() + .unwrap(); + let app = mock_app(); + plugin + .initialize(app.handle(), serde_json::Value::default()) + .expect("plugin init should succeed"); + (app, db_path) + } + + fn init_app_with_main_and_other( + main_path: PathBuf, + other_path: PathBuf, + ) -> tauri::App { + let mut plugin = Builder::::new() + .register_database("MAIN", &main_path, None) + .unwrap() + .register_database("OTHER", &other_path, None) + .unwrap() + .build() + .unwrap(); + let app = mock_app(); + plugin + .initialize(app.handle(), serde_json::Value::default()) + .expect("plugin init should succeed"); + app + } + + async fn load_and_create_test_table(app: &tauri::App, db_key: &str) { + connect_to_database(app.handle(), db_key, None) + .await + .expect("connect should succeed"); + + commands::execute( + app.state::(), + db_key.to_string(), + "CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT)".to_string(), + vec![], + None, + ) + .await + .expect("create table should succeed"); + } + + /// Holds a writer mid-transaction so `close` must abort with a checked-out connection. + async fn spawn_tracked_mid_write_regular_transaction( + app: &tauri::App, + ) -> tokio::sync::oneshot::Receiver<()> { + use sqlx_sqlite_toolkit::TransactionWriter; + + let wrapper = { + let instances = app.state::().inner().inner.read().await; + instances + .get("MAIN") + .expect("MAIN should be loaded") + .clone() + }; + let tx_key = format!("MAIN:{}", Uuid::new_v4()); + let regular_txs = app.state::().inner().clone(); + let (started_tx, started_rx) = tokio::sync::oneshot::channel(); + let tx_key_for_task = tx_key.clone(); + + let handle = tokio::spawn(async move { + let guard = wrapper.acquire_writer().await.expect("acquire writer"); + let mut writer = TransactionWriter::from(guard); + writer + .begin_immediate() + .await + .expect("begin immediate should succeed"); + writer + .execute_query( + sqlx::query("INSERT INTO test (val) VALUES (?)").bind("should-not-commit"), + ) + .await + .expect("insert should succeed"); + started_tx.send(()).ok(); + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + regular_txs.remove(&tx_key_for_task).await; + }); + + app.state::() + .inner() + .insert(tx_key, handle) + .await; + + started_rx + } fn builder_with_duplicate_paths(temp_dir: &tempfile::TempDir) -> Builder { let path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); @@ -1351,4 +1541,332 @@ mod tests { Some(std::time::Duration::from_secs(1)) ); } + + #[test] + fn test_connect_without_migrator_does_not_wait_on_migration_state() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let key = "MAIN".to_string(); + + tauri::async_runtime::block_on(async { + // `plugin.initialize()` must not run on the runtime worker thread — see + // `init_app_with_registered_db_at_path` for why we use `spawn_blocking`. + let (app, path) = + tokio::task::spawn_blocking(move || init_app_with_registered_db_at_path(&key, db_path)) + .await + .expect("plugin init task should succeed"); + + let migration_states = app.state::(); + assert!( + !migration_states.0.read().await.contains_key("MAIN"), + "databases without a migrator should not have migration state" + ); + + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .expect("connect should not block on migration state"); + assert_eq!(response.path, path); + }); + } + + #[test] + fn test_close_rolls_back_interruptible_transaction() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let key = "MAIN".to_string(); + + tauri::async_runtime::block_on(async { + let (app, _) = + tokio::task::spawn_blocking(move || init_app_with_registered_db_at_path(&key, db_path)) + .await + .expect("plugin init task should succeed"); + + load_and_create_test_table(&app, "MAIN").await; + + commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "MAIN".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("uncommitted")], + }], + None, + ) + .await + .expect("begin interruptible transaction should succeed"); + + let closed = commands::close( + app.state::(), + app.state::(), + app.state::(), + app.state::(), + "MAIN".to_string(), + ) + .await + .expect("close should succeed"); + assert!(closed); + + assert!( + app.state::() + .inner() + .inner + .read() + .await + .get("MAIN") + .is_none(), + "close should remove the loaded instance" + ); + + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .expect("reload after close should succeed"); + let rows = response + .wrapper + .fetch_all("SELECT val FROM test".into(), vec![]) + .await + .expect("fetch after close should succeed"); + + assert!( + rows.is_empty(), + "close should roll back interruptible transaction before closing" + ); + }); + } + + #[test] + fn test_close_aborts_in_flight_regular_transaction() { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let key = "MAIN".to_string(); + + tauri::async_runtime::block_on(async { + let (app, path) = + tokio::task::spawn_blocking(move || init_app_with_registered_db_at_path(&key, db_path)) + .await + .expect("plugin init task should succeed"); + + load_and_create_test_table(&app, "MAIN").await; + + let started_rx = spawn_tracked_mid_write_regular_transaction(&app).await; + started_rx + .await + .expect("regular transaction should hold writer mid-flight"); + + let closed = commands::close( + app.state::(), + app.state::(), + app.state::(), + app.state::(), + "MAIN".to_string(), + ) + .await + .expect("close should succeed"); + assert!(closed); + + assert!( + app.state::() + .inner() + .inner + .read() + .await + .get("MAIN") + .is_none(), + "close should remove the loaded instance" + ); + + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .expect("reload after close should succeed"); + let rows = response + .wrapper + .fetch_all("SELECT val FROM test".into(), vec![]) + .await + .expect("fetch after close should succeed"); + + assert!( + rows.is_empty(), + "close should abort mid-write regular transaction before closing" + ); + + std::fs::remove_file(&path).expect("database file should be safe to delete after close"); + }); + } + + #[test] + fn test_close_all_cleans_up_transactions() { + let temp_dir = tempfile::tempdir().unwrap(); + let main_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let other_path = validate::validate_database_path(temp_dir.path().join("other.db")).unwrap(); + + let main_path_for_delete = main_path.clone(); + + tauri::async_runtime::block_on(async { + let app = tokio::task::spawn_blocking(move || { + init_app_with_main_and_other(main_path, other_path) + }) + .await + .expect("plugin init task should succeed"); + + load_and_create_test_table(&app, "MAIN").await; + load_and_create_test_table(&app, "OTHER").await; + + commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "MAIN".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("uncommitted")], + }], + None, + ) + .await + .expect("begin interruptible transaction should succeed"); + + commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "OTHER".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("other-uncommitted")], + }], + None, + ) + .await + .expect("begin OTHER interruptible transaction should succeed"); + + commands::close_all( + app.state::(), + app.state::(), + app.state::(), + app.state::(), + ) + .await + .expect("close_all should succeed"); + + assert!( + app.state::() + .inner() + .inner + .read() + .await + .is_empty(), + "close_all should remove all loaded instances" + ); + + let response = connect_to_database(app.handle(), "MAIN", None) + .await + .expect("reload after close_all should succeed"); + let rows = response + .wrapper + .fetch_all("SELECT val FROM test".into(), vec![]) + .await + .expect("fetch after close_all should succeed"); + + assert!( + rows.is_empty(), + "close_all should roll back interruptible transactions before closing" + ); + + std::fs::remove_file(&main_path_for_delete) + .expect("database file should be safe to delete after close_all"); + }); + } + + #[test] + fn test_close_only_aborts_transactions_for_target_database() { + let temp_dir = tempfile::tempdir().unwrap(); + let main_path = validate::validate_database_path(temp_dir.path().join("main.db")).unwrap(); + let other_path = validate::validate_database_path(temp_dir.path().join("other.db")).unwrap(); + + tauri::async_runtime::block_on(async { + let app = tokio::task::spawn_blocking(move || { + init_app_with_main_and_other(main_path, other_path) + }) + .await + .expect("plugin init task should succeed"); + + load_and_create_test_table(&app, "MAIN").await; + load_and_create_test_table(&app, "OTHER").await; + + let main_token = commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "MAIN".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("main-uncommitted")], + }], + None, + ) + .await + .expect("begin MAIN interruptible transaction should succeed"); + + let other_token = commands::begin_interruptible_transaction( + app.state::(), + app.state::(), + "OTHER".to_string(), + vec![Statement { + query: "INSERT INTO test (val) VALUES (?)".to_string(), + values: vec![serde_json::json!("other-uncommitted")], + }], + None, + ) + .await + .expect("begin OTHER interruptible transaction should succeed"); + + let closed = commands::close( + app.state::(), + app.state::(), + app.state::(), + app.state::(), + "MAIN".to_string(), + ) + .await + .expect("close should succeed"); + assert!(closed); + + assert!( + app.state::() + .inner() + .inner + .read() + .await + .get("MAIN") + .is_none() + ); + assert!( + app.state::() + .inner() + .inner + .read() + .await + .get("OTHER") + .is_some() + ); + + let err = commands::transaction_continue( + app.state::(), + main_token, + commands::TransactionAction::Commit, + ) + .await + .expect_err("MAIN transaction should have been aborted on close"); + assert!(matches!( + err, + Error::Toolkit(sqlx_sqlite_toolkit::Error::NoActiveTransaction(_)) + )); + + let continued = commands::transaction_continue( + app.state::(), + other_token, + commands::TransactionAction::Rollback, + ) + .await + .expect("OTHER transaction should still be active"); + assert!(continued.is_none()); + }); + } }