Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ Databases must be registered on the Rust side with a stable key before they can
- Parent directory auto-creation during registration validation for file paths.
- CI check that committed `api-iife.js` matches a fresh Rollup build.

### Changed

#### `close` aborts active transactions before closing

`Database.close()` (IPC `plugin:sqlite|close`), `Database.close_all()` (IPC `plugin:sqlite|close_all`), and the Rust [`Connection::close`](src/lib.rs) API now roll back or cancel in-flight transactions before closing connection pools.

- **Interruptible transactions** (`beginInterruptibleTransaction`): explicitly rolled back via `ROLLBACK` before the pool is closed.
- **Regular transactions** (`executeTransaction`): the in-flight task is aborted and awaited so pooled connections are released before the pool closes.

Previously, `close` only aborted active subscriptions; open transactions could block a clean shutdown or leave uncommitted work on pooled connections. The same cleanup logic is available to Rust callers through [`close_database`](src/lib.rs) and [`close_all_loaded_databases`](src/lib.rs).

Transaction cleanup failures propagate as errors rather than being logged and ignored, so a successful close indicates the database file is safe to delete or recreate.

### Fixed

- `load()` no longer hangs forever for databases registered without a migrator. Migration state is only tracked when a migrator is provided; otherwise `await_migrations` proceeds immediately.
- Regenerated `api-iife.js` so all IPC calls use `dbKey` consistently.
2 changes: 1 addition & 1 deletion crates/sqlx-sqlite-toolkit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub use error::{Error, Result};
pub use pagination::{KeysetColumn, KeysetPage, SortDirection};
pub use transactions::{
ActiveInterruptibleTransaction, ActiveInterruptibleTransactions, ActiveRegularTransactions,
Statement, TransactionWriter, cleanup_all_transactions,
Statement, TransactionWriter, cleanup_all_transactions, cleanup_transactions_for_db,
};
pub use wrapper::{
DatabaseWrapper, InterruptibleTransaction, InterruptibleTransactionBuilder,
Expand Down
193 changes: 172 additions & 21 deletions crates/sqlx-sqlite-toolkit/src/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use serde_json::Value as JsonValue;
use sqlx::{Column, Row};
use sqlx_sqlite_conn_mgr::{AttachedWriteGuard, WriteGuard};
use tokio::sync::{Mutex, RwLock};
use tokio::task::AbortHandle;
use tokio::task::JoinHandle;
use tracing::{debug, warn};

#[cfg(feature = "observer")]
Expand Down Expand Up @@ -106,6 +106,8 @@ pub struct ActiveInterruptibleTransaction {
// call sqlx's rt::spawn and panic with "this functionality requires a Tokio
// context".
runtime_handle: tokio::runtime::Handle,
#[cfg(test)]
force_rollback_failure: bool,
}

impl ActiveInterruptibleTransaction {
Expand All @@ -121,9 +123,18 @@ impl ActiveInterruptibleTransaction {
writer: Some(writer),
created_at: Instant::now(),
runtime_handle: tokio::runtime::Handle::current(),
#[cfg(test)]
force_rollback_failure: false,
}
}

/// Test-only: force `rollback()` to fail when aborting via transaction state.
#[cfg(test)]
pub fn force_rollback_failure_for_test(mut self) -> Self {
self.force_rollback_failure = true;
self
}

fn writer_mut(&mut self) -> Result<&mut TransactionWriter> {
self
.writer
Expand Down Expand Up @@ -208,6 +219,15 @@ impl ActiveInterruptibleTransaction {

/// Rollback this transaction
pub async fn rollback(mut self) -> Result<()> {
#[cfg(test)]
if self.force_rollback_failure {
let db_path = self.db_path.clone();
drop(self.take_writer()?);
return Err(Error::Other(format!(
"forced rollback failure for test (db: {db_path})"
)));
}

let mut writer = self.take_writer()?;
writer.rollback().await?;

Expand Down Expand Up @@ -369,7 +389,7 @@ impl ActiveInterruptibleTransactions {
}
}

pub async fn abort_all(&self) {
pub async fn abort_all(&self) -> Result<()> {
// Drain under the lock, then release it before awaiting rollbacks so we
// don't hold the mutex across a chain of awaits.
let drained: Vec<(String, ActiveInterruptibleTransaction)> = {
Expand All @@ -383,10 +403,28 @@ impl ActiveInterruptibleTransactions {
"Rolling back interruptible transaction for database: {}",
db_path
);
if let Err(err) = tx.rollback().await {
warn!("rollback during abort_all failed (db: {db_path}): {err}");
}
tx.rollback().await?;
}

Ok(())
}

/// Roll back and remove the interruptible transaction for a single database, if any.
pub async fn abort_for_db(&self, db_key: &str) -> Result<()> {
let maybe_tx = {
let mut txs = self.inner.lock().await;
txs.remove(db_key)
};

if let Some(tx) = maybe_tx {
debug!(
"Rolling back interruptible transaction for database: {}",
db_key
);
tx.rollback().await?;
}

Ok(())
}

/// Remove and return transaction for commit/rollback.
Expand Down Expand Up @@ -435,43 +473,156 @@ impl ActiveInterruptibleTransactions {

/// Tracking for regular (non-pausable) transactions that are in-flight.
///
/// Holds abort handles so transactions can be cancelled on app exit.
/// Holds join handles so transactions can be cancelled and awaited on close.
#[derive(Clone, Default)]
pub struct ActiveRegularTransactions(Arc<RwLock<HashMap<String, AbortHandle>>>);
pub struct ActiveRegularTransactions(Arc<RwLock<HashMap<String, JoinHandle<()>>>>);

async fn abort_and_await_regular_handles(handles: Vec<(String, JoinHandle<()>)>) -> Result<()> {
for (key, handle) in handles {
debug!("Aborting regular transaction: {}", key);
handle.abort();
match handle.await {
Ok(()) => {}
Err(e) if e.is_cancelled() => {}
Err(e) => {
return Err(Error::Other(format!(
"regular transaction task panicked: {e}"
)));
}
}
}

Ok(())
}

impl ActiveRegularTransactions {
pub async fn insert(&self, key: String, abort_handle: AbortHandle) {
pub async fn insert(&self, key: String, handle: JoinHandle<()>) {
let mut txs = self.0.write().await;
txs.insert(key, abort_handle);
txs.insert(key, handle);
}

pub async fn remove(&self, key: &str) {
let mut txs = self.0.write().await;
txs.remove(key);
}

pub async fn abort_all(&self) {
let mut txs = self.0.write().await;
debug!("Aborting {} active regular transaction(s)", txs.len());
pub async fn abort_all(&self) -> Result<()> {
let handles: Vec<(String, JoinHandle<()>)> = {
let mut txs = self.0.write().await;
debug!("Aborting {} active regular transaction(s)", txs.len());
txs.drain().collect()
};

for (key, abort_handle) in txs.iter() {
debug!("Aborting regular transaction: {}", key);
abort_handle.abort();
}
abort_and_await_regular_handles(handles).await
}

txs.clear();
/// Abort in-flight regular transactions for a single database.
pub async fn abort_for_db(&self, db_key: &str) -> Result<()> {
let prefix = format!("{db_key}:");
let handles: Vec<(String, JoinHandle<()>)> = {
let mut txs = self.0.write().await;
txs.extract_if(|key, _| key.starts_with(&prefix)).collect()
};

abort_and_await_regular_handles(handles).await
}
}

/// Cleanup all transactions on app exit.
pub async fn cleanup_all_transactions(
interruptible: &ActiveInterruptibleTransactions,
regular: &ActiveRegularTransactions,
) {
) -> Result<()> {
debug!("Cleaning up all active transactions");

interruptible.abort_all().await;
regular.abort_all().await;
interruptible.abort_all().await?;
regular.abort_all().await?;

debug!("Transaction cleanup initiated");
debug!("Transaction cleanup complete");
Ok(())
}

pub async fn cleanup_transactions_for_db(
Comment thread
onehumandev marked this conversation as resolved.
db_key: &str,
interruptible_txs: &ActiveInterruptibleTransactions,
regular_txs: &ActiveRegularTransactions,
) -> Result<()> {
interruptible_txs.abort_for_db(db_key).await?;
regular_txs.abort_for_db(db_key).await?;
Ok(())
}

#[cfg(test)]
mod abort_error_tests {
use super::*;
use crate::DatabaseWrapper;
use serde_json::json;

async fn begin_test_transaction(
db: &DatabaseWrapper,
db_path: &str,
) -> ActiveInterruptibleTransaction {
let guard = db.acquire_writer().await.unwrap();
let mut writer = TransactionWriter::from(guard);
writer.begin_immediate().await.unwrap();
ActiveInterruptibleTransaction::new(
db_path.to_string(),
uuid::Uuid::new_v4().to_string(),
writer,
)
}

#[tokio::test]
async fn test_abort_for_db_propagates_rollback_failure() {
let temp_dir = tempfile::tempdir().unwrap();
let db_path = temp_dir.path().join("fail.db");
let db = DatabaseWrapper::connect(&db_path, None).await.unwrap();
db.execute(
"CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)".into(),
vec![],
)
.await
.unwrap();

let state = ActiveInterruptibleTransactions::default();
let mut tx = begin_test_transaction(&db, "fail.db").await;
tx.continue_with(vec![(
"INSERT INTO t (val) VALUES (?)",
vec![json!("uncommitted")],
)])
.await
.unwrap();
let tx = tx.force_rollback_failure_for_test();
state.insert("fail.db".into(), tx).await.unwrap();

let err = state.abort_for_db("fail.db").await.unwrap_err();
assert!(err.to_string().contains("forced rollback failure"));
}

#[tokio::test]
async fn test_abort_all_propagates_rollback_failure() {
let temp_dir = tempfile::tempdir().unwrap();
let db_path = temp_dir.path().join("fail-all.db");
let db = DatabaseWrapper::connect(&db_path, None).await.unwrap();
db.execute(
"CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)".into(),
vec![],
)
.await
.unwrap();

let state = ActiveInterruptibleTransactions::default();
let mut tx = begin_test_transaction(&db, "fail-all.db").await;
tx.continue_with(vec![(
"INSERT INTO t (val) VALUES (?)",
vec![json!("uncommitted")],
)])
.await
.unwrap();
let tx = tx.force_rollback_failure_for_test();
state.insert("fail-all.db".into(), tx).await.unwrap();

let err = state.abort_all().await.unwrap_err();
assert!(err.to_string().contains("forced rollback failure"));
}
}
Loading
Loading