diff --git a/src/io/in_memory_store.rs b/src/io/in_memory_store.rs index 8b7d41c84..4e19118ef 100644 --- a/src/io/in_memory_store.rs +++ b/src/io/in_memory_store.rs @@ -11,7 +11,9 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Mutex; use lightning::io; -use lightning::util::persist::{KVStore, PageToken, PaginatedKVStore, PaginatedListResponse}; +use lightning::util::persist::{ + KVStore, MigratableKVStore, PageToken, PaginatedKVStore, PaginatedListResponse, +}; const IN_MEMORY_PAGE_SIZE: usize = 50; @@ -96,6 +98,27 @@ impl InMemoryStore { hash_map::Entry::Vacant(_) => Ok(Vec::new()), } } + + fn list_all_keys_internal(&self) -> io::Result> { + let persisted_lock = self.persisted_bytes.lock().unwrap(); + let mut keys = Vec::new(); + + for (prefixed_namespace, namespace_entries) in persisted_lock.iter() { + let (primary_namespace, secondary_namespace) = + prefixed_namespace.split_once('/').ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "Invalid namespace format") + })?; + for key in namespace_entries.keys() { + keys.push(( + primary_namespace.to_string(), + secondary_namespace.to_string(), + key.clone(), + )); + } + } + + Ok(keys) + } } impl KVStore for InMemoryStore { @@ -187,5 +210,40 @@ impl PaginatedKVStore for InMemoryStore { } } +impl MigratableKVStore for InMemoryStore { + fn list_all_keys( + &self, + ) -> impl Future, io::Error>> + 'static + Send { + let res = self.list_all_keys_internal(); + async move { res } + } +} + unsafe impl Sync for InMemoryStore {} unsafe impl Send for InMemoryStore {} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn in_memory_store_list_all_keys() { + let store = InMemoryStore::new(); + + KVStore::write(&store, "ns_a", "sub_a", "key_a", vec![1u8]).await.unwrap(); + KVStore::write(&store, "ns_a", "sub_b", "key_b", vec![2u8]).await.unwrap(); + KVStore::write(&store, "ns_b", "", "key_c", vec![3u8]).await.unwrap(); + + let mut keys = MigratableKVStore::list_all_keys(&store).await.unwrap(); + keys.sort(); + + assert_eq!( + keys, + vec![ + ("ns_a".to_string(), "sub_a".to_string(), "key_a".to_string()), + ("ns_a".to_string(), "sub_b".to_string(), "key_b".to_string()), + ("ns_b".to_string(), "".to_string(), "key_c".to_string()), + ] + ); + } +} diff --git a/src/io/postgres_store/mod.rs b/src/io/postgres_store/mod.rs index c0770de5f..90b8cdc39 100644 --- a/src/io/postgres_store/mod.rs +++ b/src/io/postgres_store/mod.rs @@ -12,7 +12,9 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use lightning::io; -use lightning::util::persist::{KVStore, PageToken, PaginatedKVStore, PaginatedListResponse}; +use lightning::util::persist::{ + KVStore, MigratableKVStore, PageToken, PaginatedKVStore, PaginatedListResponse, +}; use lightning_types::string::PrintableString; use native_tls::TlsConnector; use postgres_native_tls::MakeTlsConnector; @@ -351,6 +353,24 @@ impl PaginatedKVStore for PostgresStore { } } +impl MigratableKVStore for PostgresStore { + fn list_all_keys( + &self, + ) -> impl Future, io::Error>> + 'static + Send { + let inner = Arc::clone(&self.inner); + let runtime = self.internal_runtime(); + async move { + let task = runtime.spawn(async move { inner.list_all_keys_internal().await }); + task.await.map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("PostgreSQL runtime task failed: {}", e), + ) + })? + } + } +} + struct PostgresStoreInner { pool: SmallPool, config: Config, @@ -725,6 +745,25 @@ impl PostgresStoreInner { Ok(keys) } + async fn list_all_keys_internal(&self) -> io::Result> { + let sql = format!( + "SELECT primary_namespace, secondary_namespace, key FROM {}", + self.kv_table_name_sql + ); + + let err_map = |e: PgError| { + let msg = format!("Failed to retrieve queried rows: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + }; + + let mut locked = self.locked_client().await?; + let rows = query_with_retry!(self, locked, err_map, locked.query(sql.as_str(), &[]))?; + + let keys: Vec<(String, String, String)> = + rows.iter().map(|row| (row.get(0), row.get(1), row.get(2))).collect(); + Ok(keys) + } + async fn list_paginated_internal( &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, ) -> io::Result { @@ -904,6 +943,29 @@ mod tests { cleanup_store(&store_1).await; } + #[tokio::test(flavor = "multi_thread")] + async fn test_postgres_store_list_all_keys() { + let store = create_test_store("test_pg_list_all_keys").await; + + KVStore::write(&store, "ns_a", "sub_a", "key_a", vec![1u8]).await.unwrap(); + KVStore::write(&store, "ns_a", "sub_b", "key_b", vec![2u8]).await.unwrap(); + KVStore::write(&store, "ns_b", "", "key_c", vec![3u8]).await.unwrap(); + + let mut keys = MigratableKVStore::list_all_keys(&store).await.unwrap(); + keys.sort(); + + assert_eq!( + keys, + vec![ + ("ns_a".to_string(), "sub_a".to_string(), "key_a".to_string()), + ("ns_a".to_string(), "sub_b".to_string(), "key_b".to_string()), + ("ns_b".to_string(), "".to_string(), "key_c".to_string()), + ] + ); + + cleanup_store(&store).await; + } + async fn kill_connection(store: &PostgresStore) { // Terminate every backend in the pool so the next op deterministically // hits a closed connection regardless of which slot `get` selects. diff --git a/src/io/sqlite_store/mod.rs b/src/io/sqlite_store/mod.rs index 076aeef9b..b2d492e85 100644 --- a/src/io/sqlite_store/mod.rs +++ b/src/io/sqlite_store/mod.rs @@ -14,7 +14,9 @@ use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use lightning::io; -use lightning::util::persist::{KVStore, PageToken, PaginatedKVStore, PaginatedListResponse}; +use lightning::util::persist::{ + KVStore, MigratableKVStore, PageToken, PaginatedKVStore, PaginatedListResponse, +}; use lightning_types::string::PrintableString; use rusqlite::{named_params, Connection}; @@ -202,6 +204,21 @@ impl PaginatedKVStore for SqliteStore { } } +impl MigratableKVStore for SqliteStore { + fn list_all_keys( + &self, + ) -> impl Future, io::Error>> + 'static + Send { + let inner = Arc::clone(&self.inner); + let fut = tokio::task::spawn_blocking(move || inner.list_all_keys_internal()); + async move { + fut.await.unwrap_or_else(|e| { + let msg = format!("Failed to IO operation due join error: {}", e); + Err(io::Error::new(io::ErrorKind::Other, msg)) + }) + } + } +} + struct SqliteStoreInner { connection: Arc>, data_dir: PathBuf, @@ -486,6 +503,35 @@ impl SqliteStoreInner { Ok(keys) } + fn list_all_keys_internal(&self) -> io::Result> { + let locked_conn = self.connection.lock().expect("lock"); + + let sql = format!( + "SELECT primary_namespace, secondary_namespace, key FROM {}", + self.kv_table_name + ); + let mut stmt = locked_conn.prepare_cached(&sql).map_err(|e| { + let msg = format!("Failed to prepare statement: {}", e); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + let mut keys = Vec::new(); + let rows_iter = + stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?))).map_err(|e| { + let msg = format!("Failed to retrieve queried rows: {}", e); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + for key in rows_iter { + keys.push(key.map_err(|e| { + let msg = format!("Failed to retrieve queried rows: {}", e); + io::Error::new(io::ErrorKind::Other, msg) + })?); + } + + Ok(keys) + } + fn list_paginated_internal( &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, ) -> io::Result { @@ -679,6 +725,34 @@ mod tests { do_test_store(&store_0, &store_1) } + #[tokio::test] + async fn test_sqlite_store_list_all_keys() { + let mut temp_path = random_storage_path(); + temp_path.push("test_sqlite_store_list_all_keys"); + let store = SqliteStore::new( + temp_path, + Some("test_db".to_string()), + Some("test_table".to_string()), + ) + .unwrap(); + + KVStore::write(&store, "ns_a", "sub_a", "key_a", vec![1u8]).await.unwrap(); + KVStore::write(&store, "ns_a", "sub_b", "key_b", vec![2u8]).await.unwrap(); + KVStore::write(&store, "ns_b", "", "key_c", vec![3u8]).await.unwrap(); + + let mut keys = MigratableKVStore::list_all_keys(&store).await.unwrap(); + keys.sort(); + + assert_eq!( + keys, + vec![ + ("ns_a".to_string(), "sub_a".to_string(), "key_a".to_string()), + ("ns_a".to_string(), "sub_b".to_string(), "key_b".to_string()), + ("ns_b".to_string(), "".to_string(), "key_c".to_string()), + ] + ); + } + #[tokio::test] async fn test_sqlite_store_paginated_listing() { let mut temp_path = random_storage_path(); diff --git a/src/io/vss_store.rs b/src/io/vss_store.rs index f6e865bd8..52e21e047 100644 --- a/src/io/vss_store.rs +++ b/src/io/vss_store.rs @@ -24,7 +24,9 @@ use bitcoin::Network; use lightning::impl_writeable_tlv_based_enum; use lightning::io::{self, Error, ErrorKind}; use lightning::sign::{EntropySource as LdkEntropySource, RandomBytes}; -use lightning::util::persist::{KVStore, PageToken, PaginatedKVStore, PaginatedListResponse}; +use lightning::util::persist::{ + KVStore, MigratableKVStore, PageToken, PaginatedKVStore, PaginatedListResponse, +}; use lightning::util::ser::{Readable, Writeable}; use prost::Message; use vss_client::client::VssClient; @@ -321,6 +323,22 @@ impl PaginatedKVStore for VssStore { } } +impl MigratableKVStore for VssStore { + fn list_all_keys( + &self, + ) -> impl Future, io::Error>> + 'static + Send { + let inner = Arc::clone(&self.inner); + let runtime = self.internal_runtime(); + async move { + let task = runtime + .spawn(async move { inner.list_all_keys_internal(&inner.async_client).await }); + task.await.map_err(|e| { + io::Error::new(io::ErrorKind::Other, format!("VSS runtime task failed: {}", e)) + })? + } + } +} + impl Drop for VssStore { fn drop(&mut self) { if let Some(runtime) = self.internal_runtime.take() { @@ -399,7 +417,7 @@ impl VssStoreInner { } } - fn extract_key(&self, unified_key: &str) -> io::Result { + fn extract_obfuscated_key<'a>(&self, unified_key: &'a str) -> io::Result<&'a str> { let mut parts = if self.schema_version == VssSchemaVersion::V1 { let mut parts = unified_key.splitn(2, '#'); let _obfuscated_namespace = parts.next(); @@ -411,14 +429,52 @@ impl VssStoreInner { parts }; match parts.next() { - Some(obfuscated_key) => { - let actual_key = self.key_obfuscator.deobfuscate(obfuscated_key)?; - Ok(actual_key) - }, + Some(obfuscated_key) => Ok(obfuscated_key), None => Err(Error::new(ErrorKind::InvalidData, "Invalid key format")), } } + fn extract_key(&self, unified_key: &str) -> io::Result { + let obfuscated_key = self.extract_obfuscated_key(unified_key)?; + let actual_key = self.key_obfuscator.deobfuscate(obfuscated_key)?; + Ok(actual_key) + } + + fn extract_namespaces(&self, unified_key: &str) -> io::Result<(String, String)> { + if self.schema_version == VssSchemaVersion::V1 { + let mut parts = unified_key.splitn(2, '#'); + let obfuscated_namespace = parts.next(); + let _obfuscated_key = parts.next(); + match (obfuscated_namespace, _obfuscated_key) { + (Some(obfuscated_namespace), Some(_obfuscated_key)) => { + let namespace = self.key_obfuscator.deobfuscate(obfuscated_namespace)?; + let mut namespace_parts = namespace.splitn(2, '#'); + let primary_namespace = namespace_parts.next(); + let secondary_namespace = namespace_parts.next(); + match (primary_namespace, secondary_namespace) { + (Some(primary_namespace), Some(secondary_namespace)) => { + Ok((primary_namespace.to_string(), secondary_namespace.to_string())) + }, + _ => Err(Error::new(ErrorKind::InvalidData, "Invalid namespace format")), + } + }, + _ => Err(Error::new(ErrorKind::InvalidData, "Invalid key format")), + } + } else { + // Default to V0 schema. + let mut parts = unified_key.splitn(3, '#'); + let primary_namespace = parts.next(); + let secondary_namespace = parts.next(); + match (primary_namespace, secondary_namespace) { + (Some(_obfuscated_key), None) => Ok(("".to_string(), "".to_string())), + (Some(primary_namespace), Some(secondary_namespace)) => { + Ok((primary_namespace.to_string(), secondary_namespace.to_string())) + }, + _ => Err(Error::new(ErrorKind::InvalidData, "Invalid key format")), + } + } + } + async fn list_keys( &self, client: &VssClient, primary_namespace: &str, secondary_namespace: &str, key_prefix: String, page_token: Option, @@ -622,6 +678,52 @@ impl VssStoreInner { Ok(PaginatedListResponse { keys, next_page_token }) } + async fn list_all_keys_internal( + &self, client: &VssClient, + ) -> io::Result> { + let mut page_token: Option = None; + let mut keys = vec![]; + loop { + let request = ListKeyVersionsRequest { + store_id: self.store_id.clone(), + key_prefix: None, + page_token, + page_size: None, + }; + + let response = client.list_key_versions(&request).await.map_err(|e| { + let msg = format!("Failed to list all keys: {}", e); + Error::new(ErrorKind::Other, msg) + })?; + + for kv in response.key_versions { + let (primary_namespace, secondary_namespace) = self.extract_namespaces(&kv.key)?; + let key = match self.extract_key(&kv.key) { + Ok(key) => key, + Err(_) + if self.schema_version == VssSchemaVersion::V0 && !kv.key.contains('#') => + { + self.key_obfuscator.deobfuscate(&kv.key)? + }, + Err(e) => return Err(e), + }; + if primary_namespace.is_empty() + && secondary_namespace.is_empty() + && key == VSS_SCHEMA_VERSION_KEY + { + continue; + } + keys.push((primary_namespace, secondary_namespace, key)); + } + + match response.next_page_token.filter(|t| !t.is_empty()) { + Some(t) => page_token = Some(t), + None => break, + } + } + Ok(keys) + } + async fn execute_locked_write< F: Future>, FN: FnOnce() -> F, @@ -1038,6 +1140,27 @@ mod tests { drop(vss_store) } + #[tokio::test] + async fn vss_list_all_keys() { + let store = build_vss_store(); + + KVStore::write(&store, "ns_a", "sub_a", "key_a", vec![1u8]).await.unwrap(); + KVStore::write(&store, "ns_a", "sub_b", "key_b", vec![2u8]).await.unwrap(); + KVStore::write(&store, "ns_b", "", "key_c", vec![3u8]).await.unwrap(); + + let mut keys = MigratableKVStore::list_all_keys(&store).await.unwrap(); + keys.sort(); + + assert_eq!( + keys, + vec![ + ("ns_a".to_string(), "sub_a".to_string(), "key_a".to_string()), + ("ns_a".to_string(), "sub_b".to_string(), "key_b".to_string()), + ("ns_b".to_string(), "".to_string(), "key_c".to_string()), + ] + ); + } + #[tokio::test] async fn vss_paginated_listing() { let store = build_vss_store();