diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 40088d65bf1..f9c105f4c4c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -238,6 +238,30 @@ jobs: ./emsdk install 4.0.21 ./emsdk activate 4.0.21 + - name: Install wasm-bindgen CLI + run: | + REQUIRED_WASM_BINDGEN_VERSION="$( + awk ' + $1 == "name" && $3 == "\"wasm-bindgen\"" { in_pkg = 1; next } + in_pkg && $1 == "version" { + gsub(/"/, "", $3); + print $3; + exit; + } + ' Cargo.lock + )" + if [ -z "${REQUIRED_WASM_BINDGEN_VERSION}" ]; then + echo "Failed to determine wasm-bindgen version from Cargo.lock" + exit 1 + fi + + INSTALLED_WASM_BINDGEN_VERSION="$(wasm-bindgen --version 2>/dev/null | awk '{print $2}' || true)" + if [ "${INSTALLED_WASM_BINDGEN_VERSION}" != "${REQUIRED_WASM_BINDGEN_VERSION}" ]; then + cargo install --locked --force wasm-bindgen-cli --version "${REQUIRED_WASM_BINDGEN_VERSION}" + fi + + wasm-bindgen --version + - name: Build typescript module sdk working-directory: crates/bindings-typescript run: pnpm build diff --git a/Cargo.lock b/Cargo.lock index 95fd0e6e1c3..276fcf514b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1094,8 +1094,14 @@ name = "connect_disconnect_client" version = "2.0.5" dependencies = [ "anyhow", + "console_error_panic_hook", + "futures", + "gloo-timers", "spacetimedb-sdk", "test-counter", + "tokio", + "wasm-bindgen", + "wasm-bindgen-futures", ] [[package]] @@ -1111,6 +1117,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "console_error_panic_hook" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" +dependencies = [ + "cfg-if", + "wasm-bindgen", +] + [[package]] name = "constant_time_eq" version = "0.3.1" @@ -2090,8 +2106,11 @@ version = "2.0.5" dependencies = [ "anyhow", "env_logger 0.10.2", + "futures", "spacetimedb-sdk", "test-counter", + "wasm-bindgen", + "wasm-bindgen-futures", ] [[package]] @@ -2552,6 +2571,80 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "gloo-console" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a17868f56b4a24f677b17c8cb69958385102fa879418052d60b50bc1727e261" +dependencies = [ + "gloo-utils", + "js-sys", + "serde", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-net" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06f627b1a58ca3d42b45d6104bf1e1a03799df472df00988b6ba21accc10580" +dependencies = [ + "futures-channel", + "futures-core", + "futures-sink", + "gloo-utils", + "http 1.3.1", + "js-sys", + "pin-project", + "serde", + "serde_json", + "thiserror 1.0.69", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "gloo-storage" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc8031e8c92758af912f9bc08fbbadd3c6f3cfcbf6b64cdf3d6a81f0139277a" +dependencies = [ + "gloo-utils", + "js-sys", + "serde", + "serde_json", + "thiserror 1.0.69", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-timers" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb143cf96099802033e0d4f4963b19fd2e0b728bcf076cd9cf7f6634f092994" +dependencies = [ + "futures-channel", + "futures-core", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "gloo-utils" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5555354113b18c547c1d3a98fbf7fb32a9ff4f6fa112ce823a21641a0ba3aa" +dependencies = [ + "js-sys", + "serde", + "serde_json", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "gzip-header" version = "1.0.0" @@ -5611,10 +5704,13 @@ version = "2.0.5" dependencies = [ "anyhow", "env_logger 0.10.2", + "futures", "serde_json", "spacetimedb-lib 2.0.5", "spacetimedb-sdk", "test-counter", + "wasm-bindgen", + "wasm-bindgen-futures", ] [[package]] @@ -7693,7 +7789,7 @@ dependencies = [ "tikv-jemalloc-ctl", "tikv-jemallocator", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.27.0", "toml 0.8.23", "toml_edit 0.22.27", "tracing", @@ -7752,7 +7848,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-stream", - "tokio-tungstenite", + "tokio-tungstenite 0.27.0", "toml 0.8.23", "tower-http 0.5.2", "tower-layer", @@ -8306,7 +8402,7 @@ dependencies = [ "spacetimedb-lib 2.0.5", "thiserror 1.0.69", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.27.0", ] [[package]] @@ -8407,7 +8503,6 @@ dependencies = [ name = "spacetimedb-sdk" version = "2.0.5" dependencies = [ - "anymap", "base64 0.21.7", "brotli", "bytes", @@ -8415,9 +8510,15 @@ dependencies = [ "flate2", "futures", "futures-channel", + "getrandom 0.3.4", + "gloo-console", + "gloo-net", + "gloo-storage", + "gloo-utils", "hex", "home", "http 1.3.1", + "js-sys", "log", "native-tls", "once_cell", @@ -8433,7 +8534,11 @@ dependencies = [ "spacetimedb-testing", "thiserror 1.0.69", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.27.0", + "tokio-tungstenite-wasm", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", ] [[package]] @@ -9154,11 +9259,16 @@ name = "test-client" version = "2.0.5" dependencies = [ "anyhow", + "console_error_panic_hook", "env_logger 0.10.2", + "futures", + "gloo-timers", "rand 0.9.2", "spacetimedb-sdk", "test-counter", "tokio", + "wasm-bindgen", + "wasm-bindgen-futures", ] [[package]] @@ -9166,6 +9276,8 @@ name = "test-counter" version = "2.0.5" dependencies = [ "anyhow", + "futures", + "gloo-timers", "spacetimedb-data-structures", ] @@ -9453,6 +9565,18 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.26.2", +] + [[package]] name = "tokio-tungstenite" version = "0.27.0" @@ -9464,7 +9588,26 @@ dependencies = [ "native-tls", "tokio", "tokio-native-tls", - "tungstenite", + "tungstenite 0.27.0", +] + +[[package]] +name = "tokio-tungstenite-wasm" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4585aa997e4afb43c64f9101c27411b8e1bf9dde49b22e3e47acdde3055b325c" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.3.1", + "httparse", + "js-sys", + "thiserror 2.0.17", + "tokio", + "tokio-tungstenite 0.26.2", + "wasm-bindgen", + "web-sys", ] [[package]] @@ -9838,6 +9981,23 @@ dependencies = [ "termcolor", ] +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http 1.3.1", + "httparse", + "log", + "rand 0.9.2", + "sha1", + "thiserror 2.0.17", + "utf-8", +] + [[package]] name = "tungstenite" version = "0.27.0" @@ -10111,9 +10271,12 @@ version = "2.0.5" dependencies = [ "anyhow", "env_logger 0.10.2", + "futures", "spacetimedb-lib 2.0.5", "spacetimedb-sdk", "test-counter", + "wasm-bindgen", + "wasm-bindgen-futures", ] [[package]] @@ -10122,8 +10285,11 @@ version = "2.0.5" dependencies = [ "anyhow", "env_logger 0.10.2", + "futures", "spacetimedb-sdk", "test-counter", + "wasm-bindgen", + "wasm-bindgen-futures", ] [[package]] diff --git a/crates/codegen/src/rust.rs b/crates/codegen/src/rust.rs index c4a24c61b62..638d6f29ec8 100644 --- a/crates/codegen/src/rust.rs +++ b/crates/codegen/src/rust.rs @@ -1661,10 +1661,11 @@ impl __sdk::InModule for RemoteTables {{ /// You must explicitly advance the connection by calling any one of: /// /// - [`DbConnection::frame_tick`]. -/// - [`DbConnection::run_threaded`]. +#[cfg_attr(not(target_arch = \"wasm32\"), doc = \"- [`DbConnection::run_threaded`].\")] +#[cfg_attr(target_arch = \"wasm32\", doc = \"- [`DbConnection::run_background_task`].\")] /// - [`DbConnection::run_async`]. /// - [`DbConnection::advance_one_message`]. -/// - [`DbConnection::advance_one_message_blocking`]. +#[cfg_attr(not(target_arch = \"wasm32\"), doc = \"- [`DbConnection::advance_one_message_blocking`].\")] /// - [`DbConnection::advance_one_message_async`]. /// /// Which of these methods you should call depends on the specific needs of your application, @@ -1762,6 +1763,7 @@ impl DbConnection {{ /// This is a low-level primitive exposed for power users who need significant control over scheduling. /// Most applications should call [`Self::run_threaded`] to spawn a thread /// which advances the connection automatically. + #[cfg(not(target_arch = \"wasm32\"))] pub fn advance_one_message_blocking(&self) -> __sdk::Result<()> {{ self.imp.advance_one_message_blocking() }} @@ -1787,10 +1789,17 @@ impl DbConnection {{ }} /// Spawn a thread which processes WebSocket messages as they are received. + #[cfg(not(target_arch = \"wasm32\"))] pub fn run_threaded(&self) -> std::thread::JoinHandle<()> {{ self.imp.run_threaded() }} + /// Spawn a background task which processes WebSocket messages as they are received. + #[cfg(target_arch = \"wasm32\")] + pub fn run_background_task(&self) {{ + self.imp.run_background_task() + }} + /// Run an `async` loop which processes WebSocket messages when polled. pub async fn run_async(&self) -> __sdk::Result<()> {{ self.imp.run_async().await diff --git a/crates/codegen/tests/snapshots/codegen__codegen_rust.snap b/crates/codegen/tests/snapshots/codegen__codegen_rust.snap index 96c5ce8afd6..200cf08e425 100644 --- a/crates/codegen/tests/snapshots/codegen__codegen_rust.snap +++ b/crates/codegen/tests/snapshots/codegen__codegen_rust.snap @@ -1391,10 +1391,17 @@ impl __sdk::InModule for RemoteTables { /// You must explicitly advance the connection by calling any one of: /// /// - [`DbConnection::frame_tick`]. -/// - [`DbConnection::run_threaded`]. +#[cfg_attr(not(target_arch = "wasm32"), doc = "- [`DbConnection::run_threaded`].")] +#[cfg_attr( + target_arch = "wasm32", + doc = "- [`DbConnection::run_background_task`]." +)] /// - [`DbConnection::run_async`]. /// - [`DbConnection::advance_one_message`]. -/// - [`DbConnection::advance_one_message_blocking`]. +#[cfg_attr( + not(target_arch = "wasm32"), + doc = "- [`DbConnection::advance_one_message_blocking`]." +)] /// - [`DbConnection::advance_one_message_async`]. /// /// Which of these methods you should call depends on the specific needs of your application, @@ -1492,6 +1499,7 @@ impl DbConnection { /// This is a low-level primitive exposed for power users who need significant control over scheduling. /// Most applications should call [`Self::run_threaded`] to spawn a thread /// which advances the connection automatically. + #[cfg(not(target_arch = "wasm32"))] pub fn advance_one_message_blocking(&self) -> __sdk::Result<()> { self.imp.advance_one_message_blocking() } @@ -1517,10 +1525,17 @@ impl DbConnection { } /// Spawn a thread which processes WebSocket messages as they are received. + #[cfg(not(target_arch = "wasm32"))] pub fn run_threaded(&self) -> std::thread::JoinHandle<()> { self.imp.run_threaded() } + /// Spawn a background task which processes WebSocket messages as they are received. + #[cfg(target_arch = "wasm32")] + pub fn run_background_task(&self) { + self.imp.run_background_task() + } + /// Run an `async` loop which processes WebSocket messages when polled. pub async fn run_async(&self) -> __sdk::Result<()> { self.imp.run_async().await diff --git a/crates/sql-parser/src/parser/mod.rs b/crates/sql-parser/src/parser/mod.rs index 9e6e5642bda..090ef55020a 100644 --- a/crates/sql-parser/src/parser/mod.rs +++ b/crates/sql-parser/src/parser/mod.rs @@ -207,9 +207,12 @@ pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult { } } -// These types determine the size of [`parse_expr`]'s stack frame. +// These types determine the size of [`parse_expr`]'s stack frame on 64-bit targets. // Changing their sizes will require updating the recursion limit to avoid stack overflows. +// wasm32 has different type layouts, so this guard does not apply there. +#[cfg(target_pointer_width = "64")] const _: () = assert!(size_of::() == 168); +#[cfg(target_pointer_width = "64")] const _: () = assert!(size_of::>() == 40); /// Parse a scalar expression diff --git a/crates/testing/src/sdk.rs b/crates/testing/src/sdk.rs index e995aa28019..84e8159f28b 100644 --- a/crates/testing/src/sdk.rs +++ b/crates/testing/src/sdk.rs @@ -3,9 +3,10 @@ use rand::seq::IteratorRandom; use spacetimedb::messages::control_db::HostType; use spacetimedb_data_structures::map::HashMap; use spacetimedb_paths::{RootDir, SpacetimePaths}; -use std::fs::create_dir_all; +use std::fs::{copy, create_dir_all}; use std::sync::{Mutex, OnceLock}; use std::thread::JoinHandle; +use std::{path::Path, path::PathBuf}; use crate::invoke_cli; use crate::modules::{start_runtime, CompilationMode, CompiledModule}; @@ -106,11 +107,20 @@ pub struct Test { /// - `SPACETIME_SDK_TEST_CLIENT_PROJECT` bound to the `client_project` path. /// - `SPACETIME_SDK_TEST_DB_NAME` bound to the database identity or name. run_command: String, + + client_runner: ClientRunner, +} + +#[derive(Clone)] +enum ClientRunner { + Default, + Web { wasm_path: String, bindgen_out_dir: String }, } pub const TEST_MODULE_PROJECT_ENV_VAR: &str = "SPACETIME_SDK_TEST_MODULE_PROJECT"; pub const TEST_DB_NAME_ENV_VAR: &str = "SPACETIME_SDK_TEST_DB_NAME"; pub const TEST_CLIENT_PROJECT_ENV_VAR: &str = "SPACETIME_SDK_TEST_CLIENT_PROJECT"; +pub const TEST_RUN_SELECTOR_ENV_VAR: &str = "SPACETIME_SDK_TEST_RUN_SELECTOR"; fn language_is_unreal(language: &str) -> bool { language.eq_ignore_ascii_case("unrealcpp") @@ -139,7 +149,7 @@ impl Test { let db_name = publish_module(paths, &file, host_type); - run_client(&self.run_command, &self.client_project, &db_name); + run_client(&self.client_runner, &self.run_command, &self.client_project, &db_name); } } @@ -188,7 +198,7 @@ macro_rules! memoized { MEMOIZED .lock() - .unwrap() + .unwrap_or_else(|e| e.into_inner()) .get_or_insert_default() .entry(($($key_tuple,)*)) .or_insert_with_key(|($($key_tuple,)*)| -> $value_ty { $body }) @@ -367,8 +377,9 @@ fn split_command_string(command: &str) -> (String, Vec) { // Note: this function is memoized to ensure we only compile each client once. fn compile_client(compile_command: &str, client_project: &str) { let client_project = client_project.to_owned(); + let compile_command = compile_command.to_owned(); - memoized!(|client_project: String| -> () { + memoized!(|(client_project, compile_command): (String, String)| -> () { let (exe, args) = split_command_string(compile_command); let output = cmd(exe, args) @@ -384,24 +395,101 @@ fn compile_client(compile_command: &str, client_project: &str) { }) } -fn run_client(run_command: &str, client_project: &str, db_name: &str) { - let (exe, args) = split_command_string(run_command); - - let output = cmd(exe, args) - .dir(client_project) - .env(TEST_CLIENT_PROJECT_ENV_VAR, client_project) - .env(TEST_DB_NAME_ENV_VAR, db_name) - .env( - "RUST_LOG", - "spacetimedb=debug,spacetimedb_client_api=debug,spacetimedb_lib=debug,spacetimedb_standalone=debug", - ) - .stderr_to_stdout() - .stdout_capture() - .unchecked() - .run() - .expect("Error running run command"); - - status_ok_or_panic(output, run_command, "(running)"); +fn run_client(runner: &ClientRunner, run_command: &str, client_project: &str, db_name: &str) { + match runner { + ClientRunner::Default => { + let (exe, args) = split_command_string(run_command); + + let command = cmd(exe, args); + + let output = command + .dir(client_project) + .env(TEST_CLIENT_PROJECT_ENV_VAR, client_project) + .env(TEST_DB_NAME_ENV_VAR, db_name) + .env( + "RUST_LOG", + "spacetimedb=debug,spacetimedb_client_api=debug,spacetimedb_lib=debug,spacetimedb_standalone=debug", + ) + .stderr_to_stdout() + .stdout_capture() + .unchecked() + .run() + .expect("Error running run command"); + + status_ok_or_panic(output, run_command, "(running)"); + } + ClientRunner::Web { + wasm_path, + bindgen_out_dir, + } => { + let rust_log = + "spacetimedb=debug,spacetimedb_client_api=debug,spacetimedb_lib=debug,spacetimedb_standalone=debug"; + + let wasm_path = Path::new(wasm_path); + let bindgen_out_dir = PathBuf::from(bindgen_out_dir); + let bindgen_out_dir = if bindgen_out_dir.is_absolute() { + bindgen_out_dir + } else { + Path::new(client_project).join(bindgen_out_dir) + }; + + create_dir_all(&bindgen_out_dir).expect("Failed to create wasm-bindgen out dir"); + + let output = cmd( + "wasm-bindgen", + [ + "--target".to_owned(), + "nodejs".to_owned(), + "--out-dir".to_owned(), + bindgen_out_dir + .to_str() + .expect("bindgen_out_dir should be valid utf-8") + .to_owned(), + wasm_path.to_str().expect("wasm_path should be valid utf-8").to_owned(), + ], + ) + .dir(client_project) + .stderr_to_stdout() + .stdout_capture() + .unchecked() + .run() + .expect("Error running wasm-bindgen"); + status_ok_or_panic(output, "wasm-bindgen", "(wasm-bindgen)"); + + let js_module_name = wasm_path + .file_stem() + .expect("wasm_path should have a filename stem") + .to_str() + .expect("wasm_path stem should be valid utf-8"); + let js_module = bindgen_out_dir.join(format!("{js_module_name}.js")); + let js_module_cjs = bindgen_out_dir.join(format!("{js_module_name}.cjs")); + copy(&js_module, &js_module_cjs).expect("Failed to create .cjs wrapper for wasm-bindgen output"); + let js_module = js_module_cjs + .to_str() + .expect("js_module path should be valid utf-8") + .to_owned(); + + let node_script = format!( + "(async () => {{\n const m = require({js_module:?});\n if (m.default) {{ await m.default(); }}\n const run = m.run || m.main || m.start;\n if (!run) throw new Error('No exported run/main/start function from wasm module');\n const runSelector = process.env.{TEST_RUN_SELECTOR_ENV_VAR} ?? '';\n const dbName = process.env.{TEST_DB_NAME_ENV_VAR};\n if (!dbName) throw new Error('Missing {TEST_DB_NAME_ENV_VAR}');\n await run(runSelector, dbName);\n}})().catch((e) => {{ console.error(e); process.exit(1); }});" + ); + + let node_args: Vec = vec!["--experimental-websocket".to_owned(), "-e".to_owned(), node_script]; + + let output = cmd("node", node_args) + .dir(&bindgen_out_dir) + .env(TEST_CLIENT_PROJECT_ENV_VAR, client_project) + .env(TEST_DB_NAME_ENV_VAR, db_name) + .env(TEST_RUN_SELECTOR_ENV_VAR, run_command) + .env("RUST_LOG", rust_log) + .stderr_to_stdout() + .stdout_capture() + .unchecked() + .run() + .expect("Error running wasm client via node"); + + status_ok_or_panic(output, run_command, "(running web)"); + } + } } #[derive(Clone, Default)] @@ -414,6 +502,8 @@ pub struct TestBuilder { generate_subdir: Option, compile_command: Option, run_command: Option, + + client_runner: Option, } impl TestBuilder { @@ -474,6 +564,16 @@ impl TestBuilder { } } + pub fn with_web_client(self, wasm_path: impl Into, bindgen_out_dir: impl Into) -> Self { + TestBuilder { + client_runner: Some(ClientRunner::Web { + wasm_path: wasm_path.into(), + bindgen_out_dir: bindgen_out_dir.into(), + }), + ..self + } + } + pub fn with_generate_private_items(self, include_private: bool) -> Self { TestBuilder { generate_include_private: include_private, @@ -512,6 +612,8 @@ impl TestBuilder { run_command: self .run_command .expect("Supply a run command using TestBuilder::with_run_command"), + + client_runner: self.client_runner.unwrap_or(ClientRunner::Default), } } } diff --git a/diff-explain.txt b/diff-explain.txt new file mode 100644 index 00000000000..d629c36c02c --- /dev/null +++ b/diff-explain.txt @@ -0,0 +1,48 @@ + Mechanical only: + + - sdks/rust/tests/test.rs:1 is mostly runner/config wiring so the same declared tests can execute either the native client or the wasm client. + - The bulk of the async churn in sdks/rust/tests/test-client/src/main.rs:41 is helper-level: shared dispatch, shared connect*, shared wait_for_all, wasm/native shims. + - sdks/rust/tests/connect_disconnect_client/src/main.rs:1 was similarly moved onto the shared async harness path, but its assertions did not change. + + The test-body changes that are actually substantive are these: + + - sdks/rust/tests/test-client/src/main.rs:1045 and sdks/rust/tests/test-client/src/main.rs:1347 + Reason: Timestamp::now() is stubbed/unimplemented on wasm32-unknown-unknown. + Why the change is acceptable: those tests are not trying to validate wall-clock acquisition; they are validating timestamp serialization/round-trip through reducers/subscriptions. A fixed timestamp preserves that assertion and + removes a wasm-only runtime failure. + - sdks/rust/tests/test-client/src/main.rs:2179 + Reason 1: it now subscribes only to one_u_8 and one_u_16 instead of subscribe_all. + Reason 2: it now marks the subscription as “ready” only after the relevant callbacks are installed. + Why the change is needed: this test only asserts on those two tables, so subscribe_all was unnecessary setup cost. More importantly, the old ordering could report “subscription applied” before the handlers under test were + actually registered, which is a real race on wasm. + + If you are looking at my current local changes on top of that branch, there are a few additional meaningful test-body fixes: + + - sdks/rust/tests/test-client/src/main.rs:2729 + Reason: it registered TestCounter expectations but never waited for them. + Why the change is needed: the old version could fall off the end early on native and became an outright hang on wasm. Adding the wait does not broaden the test; it makes the existing assertion actually execute. + - sdks/rust/tests/test-client/src/main.rs:2821 + Reason: it used a blocking Barrier inside callback-driven code. + Why the change is needed: that deadlocks wasm, because the callbacks and the wait share the same JS event loop. The atomic gate preserves the intended semantics: insert exactly once, only after all three subscriptions are + ready. + - sdks/rust/tests/test-client/src/main.rs:2919 + Reason: it used a second TestCounter as a cross-client blocking barrier. + Why the change is needed: same wasm deadlock issue. The test’s actual assertion is per-client filtering after that client’s own subscription applies; it does not require both clients to reach a global barrier first. + - sdks/rust/tests/test-client/src/main.rs:2977 + Reason: same pattern as test_parameterized_subscription. + Why the change is needed: the barrier was not part of the behavior under test, only a synchronization convenience that is invalid on wasm. Removing it keeps the test focused on RLS visibility rather than cross-client + simultaneity. + + So the short version is: + + - Most of the scary-looking diff is harness plumbing, not changed assertions. + - The real test-content changes were either: + - replacing native-only time acquisition with a deterministic value, or + - removing synchronization patterns that were accidentally relying on native’s threaded runtime and were invalid on wasm. + + If you want, I can turn this into a reviewer-oriented checklist grouped as: + + 1. pure harness changes, + 2. assertion-preserving test fixes, + 3. true behavioral test changes. + diff --git a/sdks/rust/Cargo.toml b/sdks/rust/Cargo.toml index f75dc817631..92d9781982e 100644 --- a/sdks/rust/Cargo.toml +++ b/sdks/rust/Cargo.toml @@ -9,7 +9,22 @@ rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] +default = [] allow_loopback_http_for_tests = ["spacetimedb-testing/allow_loopback_http_for_tests"] +# Run SDK integration tests with wasm+web test clients instead of native test clients. +sdk-tests-web-client = [] +web = [ + "dep:getrandom", + "dep:gloo-console", + "dep:gloo-net", + "dep:gloo-storage", + "dep:gloo-utils", + "dep:js-sys", + "dep:tokio-tungstenite-wasm", + "dep:wasm-bindgen", + "dep:wasm-bindgen-futures", + "dep:web-sys", +] [dependencies] spacetimedb-data-structures.workspace = true @@ -21,19 +36,31 @@ spacetimedb-query-builder.workspace = true spacetimedb-schema.workspace = true thiserror.workspace = true -anymap.workspace = true base64.workspace = true brotli.workspace = true bytes.workspace = true flate2.workspace = true futures.workspace = true futures-channel.workspace = true -home.workspace = true http.workspace = true log.workspace = true once_cell.workspace = true prometheus.workspace = true rand.workspace = true + +getrandom = { version = "0.3.4", features = ["wasm_js"], optional = true } +gloo-console = { version = "0.3.0", optional = true } +gloo-net = { version = "0.6.0", optional = true } +gloo-storage = { version = "0.3.0", optional = true } +gloo-utils = { version = "0.2.0", optional = true } +js-sys = { version = "0.3", optional = true } +tokio-tungstenite-wasm = { version = "0.6.0", optional = true } +wasm-bindgen = { version = "0.2.100", optional = true } +wasm-bindgen-futures = { version = "0.4.45", optional = true } +web-sys = { version = "0.3.77", features = ["HtmlDocument"], optional = true } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +home.workspace = true tokio.workspace = true tokio-tungstenite.workspace = true # native-tls 0.2.17 fails to compile with Rust 1.93.0 due non-exhaustive diff --git a/sdks/rust/src/client_cache.rs b/sdks/rust/src/client_cache.rs index d9c4c97c7f6..e72fb81594c 100644 --- a/sdks/rust/src/client_cache.rs +++ b/sdks/rust/src/client_cache.rs @@ -5,7 +5,6 @@ use crate::callbacks::CallbackId; use crate::db_connection::{debug_log, PendingMutation, SharedCell}; use crate::spacetime_module::{InModule, SpacetimeModule, TableUpdate, WithBsatn}; -use anymap::{any::Any, Map}; use bytes::Bytes; use core::any::type_name; use core::hash::Hash; @@ -16,6 +15,10 @@ use std::fs::File; use std::io::Write; use std::marker::PhantomData; use std::sync::Arc; +use std::{ + any::{Any, TypeId}, + boxed::Box, +}; /// A local mirror of the subscribed rows of one table in the database. pub struct TableCache { @@ -347,6 +350,36 @@ pub struct ClientCache { _module: PhantomData, } +// We intentionally avoid `anymap` here. +// +// In wasm test-client runs (`wasm32-unknown-unknown` under Node), +// `anymap`'s TypeId hasher path can trigger an alignment-UB check panic: +// `ptr::copy_nonoverlapping requires aligned pointers`. +// Using this local `TypeId -> Box` map preserves the +// required functionality without that runtime failure. +#[derive(Default)] +struct TypeMap { + values: HashMap>, +} + +impl TypeMap { + fn get(&self) -> Option<&T> { + self.values + .get(&TypeId::of::()) + .and_then(|value| value.downcast_ref::()) + } + + fn get_or_insert_with(&mut self, make: impl FnOnce() -> T) -> &mut T { + let value = match self.values.entry(TypeId::of::()) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => entry.insert(Box::new(make())), + }; + value + .downcast_mut::() + .expect("TypeMap entry did not match stored TypeId") + } +} + impl ClientCache { pub(crate) fn new(extra_logging: Option>) -> Self { Self { @@ -373,8 +406,7 @@ impl ClientCache { table_name: &'static str, ) -> &mut TableCache { self.tables - .entry::>>() - .or_insert_with(Default::default) + .get_or_insert_with::>>(Default::default) .entry(table_name) .or_insert_with(|| TableCache::new(self.extra_logging.clone())) } diff --git a/sdks/rust/src/credentials.rs b/sdks/rust/src/credentials.rs index bdef761048c..7174e8fea15 100644 --- a/sdks/rust/src/credentials.rs +++ b/sdks/rust/src/credentials.rs @@ -8,144 +8,305 @@ //! } //! ``` -use home::home_dir; -use spacetimedb_lib::{bsatn, de::Deserialize, ser::Serialize}; -use std::path::PathBuf; -use thiserror::Error; - -const CREDENTIALS_DIR: &str = ".spacetimedb_client_credentials"; - -#[derive(Error, Debug)] -pub enum CredentialFileError { - #[error("Failed to determine user home directory as root for credentials storage")] - DetermineHomeDir, - #[error("Error creating credential storage directory {path}")] - CreateDir { - path: PathBuf, - #[source] - source: std::io::Error, - }, - #[error("Error serializing credentials for storage in file")] - Serialize { - #[source] - source: bsatn::EncodeError, - }, - #[error("Error writing BSATN-serialized credentials to file {path}")] - Write { - path: PathBuf, - #[source] - source: std::io::Error, - }, - #[error("Error reading BSATN-serialized credentials from file {path}")] - Read { - path: PathBuf, - #[source] - source: std::io::Error, - }, - #[error("Error deserializing credentials from bytes stored in file {path}")] - Deserialize { - path: PathBuf, - #[source] - source: bsatn::DecodeError, - }, -} +#[cfg(not(feature = "web"))] +mod native_mod { + use home::home_dir; + use spacetimedb_lib::{bsatn, de::Deserialize, ser::Serialize}; + use std::path::PathBuf; + use thiserror::Error; -/// A file on disk which stores, or can store, a JWT for authenticating with SpacetimeDB. -/// -/// The file does not necessarily exist or store credentials. -/// If the credentials have been stored previously, they can be accessed with [`File::load`]. -/// New credentials can be saved to disk with [`File::save`]. -pub struct File { - filename: String, -} + const CREDENTIALS_DIR: &str = ".spacetimedb_client_credentials"; -#[derive(Serialize, Deserialize)] -struct Credentials { - token: String, -} + #[derive(Error, Debug)] + pub enum CredentialFileError { + #[error("Failed to determine user home directory as root for credentials storage")] + DetermineHomeDir, + #[error("Error creating credential storage directory {path}")] + CreateDir { + path: PathBuf, + #[source] + source: std::io::Error, + }, + #[error("Error serializing credentials for storage in file")] + Serialize { + #[source] + source: bsatn::EncodeError, + }, + #[error("Error writing BSATN-serialized credentials to file {path}")] + Write { + path: PathBuf, + #[source] + source: std::io::Error, + }, + #[error("Error reading BSATN-serialized credentials from file {path}")] + Read { + path: PathBuf, + #[source] + source: std::io::Error, + }, + #[error("Error deserializing credentials from bytes stored in file {path}")] + Deserialize { + path: PathBuf, + #[source] + source: bsatn::DecodeError, + }, + } -impl File { - /// Get a handle on a file which stores a SpacetimeDB [`Identity`] and its private access token. - /// - /// This method does not create the file or check that it exists. - /// - /// Distinct applications running as the same user on the same machine - /// may share [`Identity`]/token pairs by supplying the same `key`. - /// Users who desire distinct credentials for their application - /// should supply a unique `key` per application. + /// A file on disk which stores, or can store, a JWT for authenticating with SpacetimeDB. /// - /// No additional namespacing is provided to tie the stored token - /// to a particular SpacetimeDB instance or cluster. - /// Users who intend to connect to multiple instances or clusters - /// should use a distinct `key` per cluster. - pub fn new(key: impl Into) -> Self { - Self { filename: key.into() } + /// The file does not necessarily exist or store credentials. + /// If the credentials have been stored previously, they can be accessed with [`File::load`]. + /// New credentials can be saved to disk with [`File::save`]. + pub struct File { + filename: String, } - fn determine_home_dir() -> Result { - home_dir().ok_or(CredentialFileError::DetermineHomeDir) + #[derive(Serialize, Deserialize)] + struct Credentials { + token: String, } - fn ensure_credentials_dir() -> Result<(), CredentialFileError> { - let mut path = Self::determine_home_dir()?; - path.push(CREDENTIALS_DIR); + impl File { + /// Get a handle on a file which stores a SpacetimeDB [`Identity`] and its private access token. + /// + /// This method does not create the file or check that it exists. + /// + /// Distinct applications running as the same user on the same machine + /// may share [`Identity`]/token pairs by supplying the same `key`. + /// Users who desire distinct credentials for their application + /// should supply a unique `key` per application. + /// + /// No additional namespacing is provided to tie the stored token + /// to a particular SpacetimeDB instance or cluster. + /// Users who intend to connect to multiple instances or clusters + /// should use a distinct `key` per cluster. + pub fn new(key: impl Into) -> Self { + Self { filename: key.into() } + } - std::fs::create_dir_all(&path).map_err(|source| CredentialFileError::CreateDir { path, source }) - } + fn determine_home_dir() -> Result { + home_dir().ok_or(CredentialFileError::DetermineHomeDir) + } - fn path(&self) -> Result { - let mut path = Self::determine_home_dir()?; - path.push(CREDENTIALS_DIR); - path.push(&self.filename); - Ok(path) - } + fn ensure_credentials_dir() -> Result<(), CredentialFileError> { + let mut path = Self::determine_home_dir()?; + path.push(CREDENTIALS_DIR); - /// Store the provided `token` to disk in the file referred to by `self`. - /// - /// Future calls to [`Self::load`] on a `File` with the same key can retrieve the token. - /// - /// Expected usage is to call this from a [`super::DbConnectionBuilder::on_connect`] callback. - /// - /// ```ignore - /// DbConnection::builder() - /// .on_connect(|_ctx, _identity, token| { - /// credentials::File::new("my_app").save(token).unwrap(); - /// }) - /// ``` - pub fn save(self, token: impl Into) -> Result<(), CredentialFileError> { - Self::ensure_credentials_dir()?; - - let creds = bsatn::to_vec(&Credentials { token: token.into() }) - .map_err(|source| CredentialFileError::Serialize { source })?; - let path = self.path()?; - std::fs::write(&path, creds).map_err(|source| CredentialFileError::Write { path, source })?; - Ok(()) + std::fs::create_dir_all(&path).map_err(|source| CredentialFileError::CreateDir { path, source }) + } + + fn path(&self) -> Result { + let mut path = Self::determine_home_dir()?; + path.push(CREDENTIALS_DIR); + path.push(&self.filename); + Ok(path) + } + + /// Store the provided `token` to disk in the file referred to by `self`. + /// + /// Future calls to [`Self::load`] on a `File` with the same key can retrieve the token. + /// + /// Expected usage is to call this from a [`super::DbConnectionBuilder::on_connect`] callback. + /// + /// ```ignore + /// DbConnection::builder() + /// .on_connect(|_ctx, _identity, token| { + /// credentials::File::new("my_app").save(token).unwrap(); + /// }) + /// ``` + pub fn save(self, token: impl Into) -> Result<(), CredentialFileError> { + Self::ensure_credentials_dir()?; + + let creds = bsatn::to_vec(&Credentials { token: token.into() }) + .map_err(|source| CredentialFileError::Serialize { source })?; + let path = self.path()?; + std::fs::write(&path, creds).map_err(|source| CredentialFileError::Write { path, source })?; + Ok(()) + } + + /// Load a saved token from disk in the file referred to by `self`, + /// if they have previously been stored by [`Self::save`]. + /// + /// Returns `Err` if I/O fails, + /// `None` if credentials have not previously been stored, + /// or `Some` if credentials are successfully loaded from disk. + /// After unwrapping the `Result`, the returned `Option` can be passed to + /// [`super::DbConnectionBuilder::with_token`]. + /// + /// ```ignore + /// DbConnection::builder() + /// .with_token(credentials::File::new("my_app").load().unwrap()) + /// ``` + pub fn load(self) -> Result, CredentialFileError> { + let path = self.path()?; + + let bytes = match std::fs::read(&path) { + Ok(bytes) => bytes, + Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => return Ok(None), + Err(source) => return Err(CredentialFileError::Read { path, source }), + }; + + let creds = bsatn::from_slice::(&bytes) + .map_err(|source| CredentialFileError::Deserialize { path, source })?; + Ok(Some(creds.token)) + } } +} - /// Load a saved token from disk in the file referred to by `self`, - /// if they have previously been stored by [`Self::save`]. - /// - /// Returns `Err` if I/O fails, - /// `None` if credentials have not previously been stored, - /// or `Some` if credentials are successfully loaded from disk. - /// After unwrapping the `Result`, the returned `Option` can be passed to - /// [`super::DbConnectionBuilder::with_token`]. - /// - /// ```ignore - /// DbConnection::builder() - /// .with_token(credentials::File::new("my_app").load().unwrap()) - /// ``` - pub fn load(self) -> Result, CredentialFileError> { - let path = self.path()?; - - let bytes = match std::fs::read(&path) { - Ok(bytes) => bytes, - Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => return Ok(None), - Err(source) => return Err(CredentialFileError::Read { path, source }), - }; - - let creds = bsatn::from_slice::(&bytes) - .map_err(|source| CredentialFileError::Deserialize { path, source })?; - Ok(Some(creds.token)) +#[cfg(feature = "web")] +mod web_mod { + pub use gloo_storage::{LocalStorage, SessionStorage, Storage}; + + pub mod cookies { + use thiserror::Error; + use wasm_bindgen::{JsCast, JsValue}; + use web_sys::HtmlDocument; + + #[derive(Error, Debug)] + pub enum CookieError { + #[error("Error reading cookies: {0:?}")] + Get(JsValue), + + #[error("Error setting cookie `{key}`: {js_value:?}")] + Set { key: String, js_value: JsValue }, + } + + /// A builder for constructing and setting cookies. + pub struct Cookie { + name: String, + value: String, + path: Option, + domain: Option, + max_age: Option, + secure: bool, + same_site: Option, + } + + impl Cookie { + /// Creates a new cookie builder with the specified name and value. + pub fn new(name: impl Into, value: impl Into) -> Self { + Self { + name: name.into(), + value: value.into(), + path: None, + domain: None, + max_age: None, + secure: false, + same_site: None, + } + } + + /// Gets the value of a cookie by name. + pub fn get(name: &str) -> Result, CookieError> { + let doc = get_html_document(); + let all = doc.cookie().map_err(CookieError::Get)?; + for cookie in all.split(';') { + let cookie = cookie.trim(); + if let Some((k, v)) = cookie.split_once('=') { + if k == name { + return Ok(Some(v.to_string())); + } + } + } + + Ok(None) + } + + /// Sets the `Path` attribute (e.g., "/"). + pub fn path(mut self, path: impl Into) -> Self { + self.path = Some(path.into()); + self + } + + /// Sets the `Domain` attribute (e.g., "example.com"). + pub fn domain(mut self, domain: impl Into) -> Self { + self.domain = Some(domain.into()); + self + } + + /// Sets the `Max-Age` attribute in seconds. + pub fn max_age(mut self, seconds: i32) -> Self { + self.max_age = Some(seconds); + self + } + + /// Toggles the `Secure` flag. + /// Defaults to `false`. + pub fn secure(mut self, enabled: bool) -> Self { + self.secure = enabled; + self + } + + /// Sets the `SameSite` attribute (`Strict`, `Lax`, or `None`). + pub fn same_site(mut self, same_site: SameSite) -> Self { + self.same_site = Some(same_site); + self + } + + pub fn set(self) -> Result<(), CookieError> { + let doc = get_html_document(); + let mut parts = vec![format!("{}={}", self.name, self.value)]; + + if let Some(path) = self.path { + parts.push(format!("Path={path}")); + } + if let Some(domain) = self.domain { + parts.push(format!("Domain={domain}")); + } + if let Some(age) = self.max_age { + parts.push(format!("Max-Age={age}")); + } + if self.secure { + parts.push("Secure".into()); + } + if let Some(same) = self.same_site { + parts.push(format!("SameSite={same}")); + } + + let cookie_str = parts.join("; "); + doc.set_cookie(&cookie_str).map_err(|e| CookieError::Set { + key: self.name.clone(), + js_value: e, + }) + } + + /// Deletes the cookie by setting its value to empty and `Max-Age=0`. + pub fn delete(self) -> Result<(), CookieError> { + self.value("").max_age(0).set() + } + + /// Helper to override value for delete + fn value(mut self, value: impl Into) -> Self { + self.value = value.into(); + self + } + } + + /// Controls the `SameSite` attribute for cookies. + pub enum SameSite { + Strict, + Lax, + None, + } + + impl std::fmt::Display for SameSite { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SameSite::Strict => f.write_str("Strict"), + SameSite::Lax => f.write_str("Lax"), + SameSite::None => f.write_str("None"), + } + } + } + + fn get_html_document() -> HtmlDocument { + gloo_utils::document().unchecked_into::() + } } } + +#[cfg(not(feature = "web"))] +pub use native_mod::*; + +#[cfg(feature = "web")] +pub use web_mod::*; diff --git a/sdks/rust/src/db_connection.rs b/sdks/rust/src/db_connection.rs index 60351f4d9fe..d0d2b8a19fe 100644 --- a/sdks/rust/src/db_connection.rs +++ b/sdks/rust/src/db_connection.rs @@ -32,17 +32,15 @@ use crate::{ }; use bytes::Bytes; use futures::StreamExt; +#[cfg(feature = "web")] +use futures::{pin_mut, FutureExt}; use futures_channel::mpsc; use http::Uri; use spacetimedb_client_api_messages::websocket::{self as ws, common::QuerySetId}; use spacetimedb_lib::{bsatn, ser::Serialize, ConnectionId, Identity, Timestamp}; use spacetimedb_sats::Deserialize; -use std::{ - fs::{File, OpenOptions}, - io::Write, - path::PathBuf, - sync::{atomic::AtomicU32, Arc, Mutex as StdMutex, OnceLock}, -}; +use std::sync::{atomic::AtomicU32, Arc, Mutex as StdMutex, OnceLock}; +#[cfg(not(feature = "web"))] use tokio::{ runtime::{self, Runtime}, sync::Mutex as TokioMutex, @@ -50,12 +48,18 @@ use tokio::{ pub(crate) type SharedCell = Arc>; +#[cfg(not(feature = "web"))] +type SharedAsyncCell = Arc>; +#[cfg(feature = "web")] +type SharedAsyncCell = SharedCell; + /// Implementation of `DbConnection`, `EventContext`, /// and anything else that provides access to the database connection. /// /// This must be relatively cheaply `Clone`-able, and have internal sharing, /// as numerous operations will clone it to get new handles on the connection. pub struct DbContextImpl { + #[cfg(not(feature = "web"))] runtime: runtime::Handle, /// All the state which is safe to hold a lock on while running callbacks. @@ -69,7 +73,7 @@ pub struct DbContextImpl { /// Receiver channel for WebSocket messages, /// which are pre-parsed in the background by [`parse_loop`]. - recv: Arc>>>, + recv: SharedAsyncCell>>, /// Channel into which operations which apparently mutate SDK state, /// e.g. registering callbacks, push [`PendingMutation`] messages, @@ -79,7 +83,7 @@ pub struct DbContextImpl { /// Receive end of `pending_mutations_send`, /// from which [Self::apply_pending_mutations] and friends read mutations. - pending_mutations_recv: Arc>>>, + pending_mutations_recv: SharedAsyncCell>>, /// This connection's `Identity`. /// @@ -98,6 +102,7 @@ pub struct DbContextImpl { impl Clone for DbContextImpl { fn clone(&self) -> Self { Self { + #[cfg(not(feature = "web"))] runtime: self.runtime.clone(), // Being very explicit with `Arc::clone` here, // since we'll be doing `DbContextImpl::clone` very frequently, @@ -320,9 +325,10 @@ impl DbContextImpl { /// Apply all queued [`PendingMutation`]s. fn apply_pending_mutations(&self) -> crate::Result<()> { - while let Ok(Some(pending_mutation)) = self.pending_mutations_recv.blocking_lock().try_next() { + while let Ok(Some(pending_mutation)) = get_lock_sync(&self.pending_mutations_recv).try_next() { self.apply_mutation(pending_mutation)?; } + Ok(()) } @@ -526,7 +532,7 @@ impl DbContextImpl { // returns `Err(_)`. Similar behavior as `Iterator::next` and // `Stream::poll_next`. No comment on whether this is a good mental // model or not. - let res = match self.recv.blocking_lock().try_next() { + let res = match get_lock_sync(&self.recv).try_next() { Ok(None) => { let disconnect_ctx = self.make_event_ctx(None); self.invoke_disconnected(&disconnect_ctx); @@ -549,8 +555,8 @@ impl DbContextImpl { // We call this out as an incorrect and unsupported thing to do. #![allow(clippy::await_holding_lock)] - let mut pending_mutations = self.pending_mutations_recv.lock().await; - let mut recv = self.recv.lock().await; + let mut pending_mutations = get_lock_async(&self.pending_mutations_recv).await; + let mut recv = get_lock_async(&self.recv).await; // Always process pending mutations before WS messages, if they're available, // so that newly registered callbacks run on messages. @@ -561,15 +567,28 @@ impl DbContextImpl { return Message::Local(pending_mutation.unwrap()); } + #[cfg(not(feature = "web"))] tokio::select! { pending_mutation = pending_mutations.next() => Message::Local(pending_mutation.unwrap()), incoming_message = recv.next() => Message::Ws(incoming_message), } + + #[cfg(feature = "web")] + { + let (pending_fut, recv_fut) = (pending_mutations.next().fuse(), recv.next().fuse()); + pin_mut!(pending_fut, recv_fut); + + futures::select! { + pending_mutation = pending_fut => Message::Local(pending_mutation.unwrap()), + incoming_message = recv_fut => Message::Ws(incoming_message), + } + } } /// Like [`Self::advance_one_message`], but sleeps the thread until a message is available. /// /// Called by the autogenerated `DbConnection` method of the same name. + #[cfg(not(feature = "web"))] pub fn advance_one_message_blocking(&self) -> crate::Result<()> { match self.runtime.block_on(self.get_message()) { Message::Local(pending) => self.apply_mutation(pending), @@ -608,6 +627,7 @@ impl DbContextImpl { /// Spawn a thread which does [`Self::advance_one_message_blocking`] in a loop. /// /// Called by the autogenerated `DbConnection` method of the same name. + #[cfg(not(feature = "web"))] pub fn run_threaded(&self) -> std::thread::JoinHandle<()> { let this = self.clone(); std::thread::spawn(move || loop { @@ -619,6 +639,23 @@ impl DbContextImpl { }) } + /// Spawn a background task which does [`Self::advance_one_message_async`] in a loop. + /// + /// Called by the autogenerated `DbConnection` method of the same name. + #[cfg(feature = "web")] + pub fn run_background_task(&self) { + let this = self.clone(); + wasm_bindgen_futures::spawn_local(async move { + loop { + match this.advance_one_message_async().await { + Ok(()) => (), + Err(e) if error_is_normal_disconnect(&e) => return, + Err(e) => panic!("{e:?}"), + } + } + }) + } + /// An async task which does [`Self::advance_one_message_async`] in a loop. /// /// Called by the autogenerated `DbConnection` method of the same name. @@ -745,6 +782,7 @@ pub(crate) struct DbContextImplInner { /// `Some` if not within the context of an outer runtime. The `Runtime` must /// then live as long as `Self`. #[allow(unused)] + #[cfg(not(feature = "web"))] runtime: Option, db_callbacks: DbCallbacks, @@ -857,6 +895,7 @@ You must explicitly advance the connection by calling any one of: - `DbConnection::frame_tick`. - `DbConnection::run_threaded`. +- `DbConnection::run_background_task`. - `DbConnection::run_async`. - `DbConnection::advance_one_message`. - `DbConnection::advance_one_message_blocking`. @@ -865,13 +904,21 @@ You must explicitly advance the connection by calling any one of: Which of these methods you should call depends on the specific needs of your application, but you must call one of them, or else the connection will never progress. "] + #[cfg(not(feature = "web"))] pub fn build(self) -> crate::Result { let imp = self.build_impl()?; Ok(::new(imp)) } + #[cfg(feature = "web")] + pub async fn build(self) -> crate::Result { + let imp = self.build_impl().await?; + Ok(::new(imp)) + } + /// Open a WebSocket connection, build an empty client cache, &c, /// to construct a [`DbContextImpl`]. + #[cfg(not(feature = "web"))] fn build_impl(self) -> crate::Result> { let extra_logging = self .additional_logging_path @@ -884,9 +931,6 @@ but you must call one of them, or else the connection will never progress. .map(|file| Arc::new(StdMutex::new(file))); let (runtime, handle) = enter_or_create_runtime()?; - let db_callbacks = DbCallbacks::default(); - let reducer_callbacks = ReducerCallbacks::default(); - let procedure_callbacks = ProcedureCallbacks::default(); let connection_id_override = get_connection_id_override(); let ws_connection = tokio::task::block_in_place(|| { @@ -902,44 +946,58 @@ but you must call one of them, or else the connection will never progress. source: InternalError::new("Failed to initiate WebSocket connection").with_cause(source), })?; - let (_websocket_loop_handle, raw_msg_recv, raw_msg_send) = - ws_connection.spawn_message_loop(&handle, extra_logging.clone()); - let (_parse_loop_handle, parsed_recv_chan) = - spawn_parse_loop::(raw_msg_recv, &handle, extra_logging.clone()); + let (_websocket_loop_handle, raw_msg_recv, raw_msg_send) = ws_connection.spawn_message_loop(&handle); + let (_parse_loop_handle, parsed_recv_chan) = spawn_parse_loop::(raw_msg_recv, &handle); + let parsed_recv_chan = Arc::new(TokioMutex::new(parsed_recv_chan)); - let inner = Arc::new(StdMutex::new(DbContextImplInner { - runtime, - - db_callbacks, - reducer_callbacks, - subscriptions: SubscriptionManager::default(), + let (pending_mutations_send, pending_mutations_recv) = mpsc::unbounded(); + let pending_mutations_recv = Arc::new(TokioMutex::new(pending_mutations_recv)); + + let inner_ctx = build_db_ctx_inner(runtime, self.on_connect, self.on_connect_error, self.on_disconnect); + Ok(build_db_ctx( + handle, + inner_ctx, + raw_msg_send, + parsed_recv_chan, + pending_mutations_send, + pending_mutations_recv, + connection_id_override, + )) + } - on_connect: self.on_connect, - on_connect_error: self.on_connect_error, - on_disconnect: self.on_disconnect, - procedure_callbacks, - })); + /// Open a WebSocket connection, build an empty client cache, &c, + /// to construct a [`DbContextImpl`]. + #[cfg(feature = "web")] + async fn build_impl(self) -> crate::Result> { + let connection_id_override = get_connection_id_override(); + let ws_connection = WsConnection::connect( + self.uri.clone().unwrap(), + self.database_name.as_ref().unwrap(), + self.token.as_deref(), + connection_id_override, + self.params, + ) + .await + .map_err(|source| crate::Error::FailedToConnect { + source: InternalError::new("Failed to initiate WebSocket connection").with_cause(source), + })?; - let mut cache = ClientCache::new(extra_logging.clone()); - M::register_tables(&mut cache); - let cache = Arc::new(StdMutex::new(cache)); - let send_chan = Arc::new(StdMutex::new(Some(raw_msg_send))); + let (raw_msg_recv, raw_msg_send) = ws_connection.spawn_message_loop(); + let parsed_recv_chan = spawn_parse_loop::(raw_msg_recv); + let parsed_recv_chan = Arc::new(StdMutex::new(parsed_recv_chan)); let (pending_mutations_send, pending_mutations_recv) = mpsc::unbounded(); - let ctx_imp = DbContextImpl { - runtime: handle, - inner, - send_chan, - cache, - recv: Arc::new(TokioMutex::new(parsed_recv_chan)), - pending_mutations_send, - pending_mutations_recv: Arc::new(TokioMutex::new(pending_mutations_recv)), - identity: Arc::new(StdMutex::new(None)), - connection_id: Arc::new(StdMutex::new(connection_id_override)), - extra_logging, - }; + let pending_mutations_recv = Arc::new(StdMutex::new(pending_mutations_recv)); - Ok(ctx_imp) + let inner_ctx = build_db_ctx_inner(self.on_connect, self.on_connect_error, self.on_disconnect); + Ok(build_db_ctx( + inner_ctx, + raw_msg_send, + parsed_recv_chan, + pending_mutations_send, + pending_mutations_recv, + connection_id_override, + )) } /// Set the URI of the SpacetimeDB host which is running the remote database. @@ -1074,9 +1132,63 @@ Instead of registering multiple `on_disconnect` callbacks, register a single cal } } +/// Create a [`DbContextImplInner`] wrapped in `Arc>`. +fn build_db_ctx_inner( + #[cfg(not(feature = "web"))] runtime: Option, + + on_connect_cb: Option>, + on_connect_error_cb: Option>, + on_disconnect_cb: Option>, +) -> Arc>> { + Arc::new(StdMutex::new(DbContextImplInner { + #[cfg(not(feature = "web"))] + runtime, + + db_callbacks: DbCallbacks::default(), + reducer_callbacks: ReducerCallbacks::default(), + subscriptions: SubscriptionManager::default(), + + on_connect: on_connect_cb, + on_connect_error: on_connect_error_cb, + on_disconnect: on_disconnect_cb, + + procedure_callbacks: ProcedureCallbacks::default(), + })) +} + +/// Assemble and return a [`DbContextImpl`] from the provided [`DbContextImplInner`], and channels. +fn build_db_ctx( + #[cfg(not(feature = "web"))] runtime_handle: runtime::Handle, + + inner_ctx: Arc>>, + raw_msg_send: mpsc::UnboundedSender, + parsed_msg_recv: SharedAsyncCell>>, + pending_mutations_send: mpsc::UnboundedSender>, + pending_mutations_recv: SharedAsyncCell>>, + connection_id: Option, +) -> DbContextImpl { + let mut cache = ClientCache::default(); + M::register_tables(&mut cache); + let cache = Arc::new(StdMutex::new(cache)); + + DbContextImpl { + #[cfg(not(feature = "web"))] + runtime: runtime_handle, + inner: inner_ctx, + send_chan: Arc::new(StdMutex::new(Some(raw_msg_send))), + cache, + recv: parsed_msg_recv, + pending_mutations_send, + pending_mutations_recv, + identity: Arc::new(StdMutex::new(None)), + connection_id: Arc::new(StdMutex::new(connection_id)), + } +} + // When called from within an async context, return a handle to it (and no // `Runtime`), otherwise create a fresh `Runtime` and return it along with a // handle to it. +#[cfg(not(feature = "web"))] fn enter_or_create_runtime() -> crate::Result<(Option, runtime::Handle)> { match runtime::Handle::try_current() { Err(e) if e.is_missing_context() => { @@ -1099,7 +1211,31 @@ fn enter_or_create_runtime() -> crate::Result<(Option, runtime::Handle) } } -#[derive(Debug)] +/// Synchronous lock helper: native = blocking_lock, web = lock().unwrap() +#[cfg(not(feature = "web"))] +fn get_lock_sync(mutex: &TokioMutex) -> tokio::sync::MutexGuard<'_, T> { + mutex.blocking_lock() +} + +/// Synchronous lock helper: native = blocking_lock, web = lock().unwrap() +#[cfg(feature = "web")] +fn get_lock_sync(mutex: &StdMutex) -> std::sync::MutexGuard<'_, T> { + mutex.lock().unwrap() +} + +// Async‐lock helper: native = .lock().await, web = lock().unwrap() inside async fn +#[cfg(not(feature = "web"))] +async fn get_lock_async(mutex: &TokioMutex) -> tokio::sync::MutexGuard<'_, T> { + mutex.lock().await +} + +// Async‐lock helper: native = .lock().await, web = lock().unwrap() inside async fn +#[cfg(feature = "web")] +pub async fn get_lock_async(mutex: &StdMutex) -> std::sync::MutexGuard<'_, T> { + // still async, but does the sync lock immediately + mutex.lock().unwrap() +} + enum ParsedMessage { TransactionUpdate(M::DbUpdate), IdentityToken(Identity, Box, ConnectionId), @@ -1127,6 +1263,7 @@ enum ParsedMessage { }, } +#[cfg(not(feature = "web"))] fn spawn_parse_loop( raw_message_recv: mpsc::UnboundedReceiver, handle: &runtime::Handle, @@ -1137,6 +1274,15 @@ fn spawn_parse_loop( (handle, parsed_message_recv) } +#[cfg(feature = "web")] +fn spawn_parse_loop( + raw_message_recv: mpsc::UnboundedReceiver, +) -> mpsc::UnboundedReceiver> { + let (parsed_message_send, parsed_message_recv) = mpsc::unbounded(); + wasm_bindgen_futures::spawn_local(parse_loop(raw_message_recv, parsed_message_send)); + parsed_message_recv +} + /// A loop which reads raw WS messages from `recv`, parses them into domain types, /// and pushes the [`ParsedMessage`]s into `send`. async fn parse_loop( diff --git a/sdks/rust/src/websocket.rs b/sdks/rust/src/websocket.rs index 2b15b85a2d8..79d600cc404 100644 --- a/sdks/rust/src/websocket.rs +++ b/sdks/rust/src/websocket.rs @@ -8,27 +8,37 @@ use std::mem; use std::sync::{Arc, Mutex}; use std::time::Duration; +#[cfg(not(feature = "web"))] use bytes::Bytes; -use futures::{SinkExt, StreamExt as _, TryStreamExt}; +#[cfg(not(feature = "web"))] +use futures::TryStreamExt; +use futures::{SinkExt, StreamExt as _}; use futures_channel::mpsc; use http::uri::{InvalidUri, Scheme, Uri}; use spacetimedb_client_api_messages::websocket as ws; use spacetimedb_lib::{bsatn, ConnectionId}; use thiserror::Error; -use tokio::task::JoinHandle; -use tokio::time::Instant; -use tokio::{net::TcpStream, runtime}; +#[cfg(not(feature = "web"))] +use tokio::{net::TcpStream, runtime, task::JoinHandle, time::Instant}; +#[cfg(not(feature = "web"))] use tokio_tungstenite::{ connect_async_with_config, tungstenite::client::IntoClientRequest, tungstenite::protocol::{Message as WebSocketMessage, WebSocketConfig}, MaybeTlsStream, WebSocketStream, }; +#[cfg(feature = "web")] +use tokio_tungstenite_wasm::{Message as WebSocketMessage, WebSocketStream}; use crate::compression::decompress_server_message; use crate::db_connection::debug_log; use crate::metrics::CLIENT_METRICS; +#[cfg(not(feature = "web"))] +type TokioTungsteniteError = tokio_tungstenite::tungstenite::Error; +#[cfg(feature = "web")] +type TokioTungsteniteError = tokio_tungstenite_wasm::Error; + #[derive(Error, Debug, Clone)] pub enum UriError { #[error("Unknown URI scheme {scheme}, expected http, https, ws or wss")] @@ -60,7 +70,7 @@ pub enum WsError { uri: Uri, #[source] // `Arc` is required for `Self: Clone`, as `tungstenite::Error: !Clone`. - source: Arc, + source: Arc, }, #[error("Received empty raw message, but valid messages always start with a one-byte compression flag")] @@ -82,11 +92,18 @@ pub enum WsError { #[error("Unrecognized compression scheme: {scheme:#x}")] UnknownCompressionScheme { scheme: u8 }, + + #[cfg(feature = "web")] + #[error("Token verification error: {0}")] + TokenVerification(String), } pub(crate) struct WsConnection { db_name: Box, + #[cfg(not(feature = "web"))] sock: WebSocketStream>, + #[cfg(feature = "web")] + sock: WebSocketStream, } fn parse_scheme(scheme: Option) -> Result { @@ -114,7 +131,29 @@ pub(crate) struct WsParams { pub confirmed: Option, } +#[cfg(not(feature = "web"))] fn make_uri(host: Uri, db_name: &str, connection_id: Option, params: WsParams) -> Result { + make_uri_impl(host, db_name, connection_id, params, None) +} + +#[cfg(feature = "web")] +fn make_uri( + host: Uri, + db_name: &str, + connection_id: Option, + params: WsParams, + token: Option<&str>, +) -> Result { + make_uri_impl(host, db_name, connection_id, params, token) +} + +fn make_uri_impl( + host: Uri, + db_name: &str, + connection_id: Option, + params: WsParams, + token: Option<&str>, +) -> Result { let mut parts = host.into_parts(); let scheme = parse_scheme(parts.scheme.take())?; parts.scheme = Some(scheme); @@ -158,6 +197,11 @@ fn make_uri(host: Uri, db_name: &str, connection_id: Option, param path.push_str(if confirmed { "true" } else { "false" }); } + // Specify the `token` param if needed + if let Some(token) = token { + path.push_str(&format!("&token={token}")); + } + parts.path_and_query = Some(path.parse().map_err(|source: InvalidUri| UriError::InvalidUri { source: Arc::new(source), })?); @@ -175,6 +219,7 @@ fn make_uri(host: Uri, db_name: &str, connection_id: Option, param // rather than having Tungstenite manage its own connections. Should this library do // the same? +#[cfg(not(feature = "web"))] fn make_request( host: Uri, db_name: &str, @@ -192,6 +237,7 @@ fn make_request( Ok(req) } +#[cfg(not(feature = "web"))] fn request_insert_protocol_header(req: &mut http::Request<()>) { req.headers_mut().insert( http::header::SEC_WEBSOCKET_PROTOCOL, @@ -199,6 +245,7 @@ fn request_insert_protocol_header(req: &mut http::Request<()>) { ); } +#[cfg(not(feature = "web"))] fn request_insert_auth_header(req: &mut http::Request<()>, token: Option<&str>) { if let Some(token) = token { let auth = ["Bearer ", token].concat().try_into().unwrap(); @@ -206,9 +253,57 @@ fn request_insert_auth_header(req: &mut http::Request<()>, token: Option<&str>) } } +#[cfg(feature = "web")] +async fn fetch_ws_token(host: &Uri, auth_token: &str) -> Result { + use gloo_net::http::{Method, RequestBuilder}; + use js_sys::{Reflect, JSON}; + use wasm_bindgen::{JsCast, JsValue}; + + let url = format!("{host}v1/identity/websocket-token"); + + // helpers to convert gloo_net::Error or JsValue into WsError::TokenVerification + let gloo_to_ws_err = |e: gloo_net::Error| match e { + gloo_net::Error::JsError(js_err) => WsError::TokenVerification(js_err.message), + gloo_net::Error::SerdeError(e) => WsError::TokenVerification(e.to_string()), + gloo_net::Error::GlooError(msg) => WsError::TokenVerification(msg), + }; + let js_to_ws_err = |e: JsValue| { + if let Some(err) = e.dyn_ref::() { + WsError::TokenVerification(err.message().into()) + } else if let Some(s) = e.as_string() { + WsError::TokenVerification(s) + } else { + WsError::TokenVerification(format!("{e:?}")) + } + }; + + let res = RequestBuilder::new(&url) + .method(Method::POST) + .header("Authorization", &format!("Bearer {auth_token}")) + .send() + .await + .map_err(gloo_to_ws_err)?; + + if !res.ok() { + return Err(WsError::TokenVerification(format!( + "HTTP error: {} {}", + res.status(), + res.status_text() + ))); + } + + let body = res.text().await.map_err(gloo_to_ws_err)?; + let json = JSON::parse(&body).map_err(js_to_ws_err)?; + let token_js = Reflect::get(&json, &JsValue::from_str("token")).map_err(js_to_ws_err)?; + token_js + .as_string() + .ok_or_else(|| WsError::TokenVerification("`token` parsing failed".into())) +} + /// If `res` evaluates to `Err(e)`, log a warning in the form `"{}: {:?}", $cause, e`. /// /// Could be trivially written as a function, but macro-ifying it preserves the source location of the log. +#[cfg(not(feature = "web"))] macro_rules! maybe_log_error { ($extra_logging:expr, $cause:expr, $res:expr) => { if let Err(e) = $res { @@ -220,6 +315,7 @@ macro_rules! maybe_log_error { } impl WsConnection { + #[cfg(not(feature = "web"))] pub(crate) async fn connect( host: Uri, db_name: &str, @@ -251,6 +347,34 @@ impl WsConnection { }) } + #[cfg(feature = "web")] + pub(crate) async fn connect( + host: Uri, + db_name: &str, + token: Option<&str>, + connection_id: Option, + params: WsParams, + ) -> Result { + let token = if let Some(auth_token) = token { + Some(fetch_ws_token(&host, auth_token).await?) + } else { + None + }; + + let uri = make_uri(host, db_name, connection_id, params, token.as_deref())?; + let sock = tokio_tungstenite_wasm::connect_with_protocols(&uri.to_string(), &[ws::v2::BIN_PROTOCOL]) + .await + .map_err(|source| WsError::Tungstenite { + uri, + source: Arc::new(source), + })?; + + Ok(WsConnection { + db_name: db_name.into(), + sock, + }) + } + pub(crate) fn parse_response(bytes: &[u8]) -> Result { let bytes = &*decompress_server_message(bytes)?; bsatn::from_slice(bytes).map_err(|source| WsError::DeserializeMessage { source }) @@ -260,6 +384,7 @@ impl WsConnection { WebSocketMessage::Binary(bsatn::to_vec(&msg).unwrap().into()) } + #[cfg(not(feature = "web"))] async fn message_loop( mut self, incoming_messages: mpsc::UnboundedSender, @@ -405,6 +530,7 @@ impl WsConnection { } } + #[cfg(not(feature = "web"))] pub(crate) fn spawn_message_loop( self, runtime: &runtime::Handle, @@ -419,4 +545,99 @@ impl WsConnection { let handle = runtime.spawn(self.message_loop(incoming_send, outgoing_recv, extra_logging)); (handle, incoming_recv, outgoing_send) } + + #[cfg(feature = "web")] + pub(crate) fn spawn_message_loop( + self, + ) -> ( + mpsc::UnboundedReceiver, + mpsc::UnboundedSender, + ) { + let websocket_received = CLIENT_METRICS.websocket_received.with_label_values(&self.db_name); + let websocket_received_msg_size = CLIENT_METRICS + .websocket_received_msg_size + .with_label_values(&self.db_name); + let record_metrics = move |msg_size: usize| { + websocket_received.inc(); + websocket_received_msg_size.observe(msg_size as f64); + }; + + let (outgoing_tx, outgoing_rx) = mpsc::unbounded::(); + let (incoming_tx, incoming_rx) = mpsc::unbounded::(); + + let (mut ws_writer, ws_reader) = self.sock.split(); + + wasm_bindgen_futures::spawn_local(async move { + let mut incoming = ws_reader.fuse(); + let mut outgoing = outgoing_rx.fuse(); + + loop { + futures::select! { + // 1) inbound WS frames + inbound = incoming.next() => match inbound { + Some(Err(tokio_tungstenite_wasm::Error::ConnectionClosed)) | None => { + gloo_console::log!("Connection closed"); + break; + }, + + Some(Ok(WebSocketMessage::Binary(bytes))) => { + record_metrics(bytes.len()); + // parse + forward into `incoming_tx` + match Self::parse_response(&bytes) { + Ok(msg) => if let Err(_e) = incoming_tx.unbounded_send(msg) { + gloo_console::warn!("Incoming receiver dropped."); + break; + }, + Err(e) => { + gloo_console::warn!( + "Error decoding WebSocketMessage::Binay payload: ", + format!("{:?}", e) + ); + }, + } + }, + + Some(Ok(WebSocketMessage::Close(r))) => { + let reason: String = if let Some(r) = r { + format!("{}:{:?}", r, r.code) + } else {String::default()}; + gloo_console::warn!("Connection Closed.", reason); + let _ = ws_writer.close().await; + break; + }, + + Some(Err(e)) => { + gloo_console::warn!( + "Error reading message from read WebSocket stream: ", + format!("{:?}",e) + ); + break; + }, + + Some(Ok(other)) => { + record_metrics(other.len()); + gloo_console::warn!("Unexpected WebSocket message: ", format!("{:?}",other)); + } + }, + + // 2) outbound messages + outbound = outgoing.next() => if let Some(client_msg) = outbound { + let raw = Self::encode_message(client_msg); + if let Err(e) = ws_writer.send(raw).await { + gloo_console::warn!("Error sending outgoing message:", format!("{:?}",e)); + break; + } + } else { + // channel closed, so we're done sending + if let Err(e) = ws_writer.close().await { + gloo_console::warn!("Error sending close frame:", format!("{:?}", e)); + } + break; + }, + } + } + }); + + (incoming_rx, outgoing_tx) + } } diff --git a/sdks/rust/tests/connect_disconnect_client/Cargo.toml b/sdks/rust/tests/connect_disconnect_client/Cargo.toml index 126d8568ac5..280cc840d45 100644 --- a/sdks/rust/tests/connect_disconnect_client/Cargo.toml +++ b/sdks/rust/tests/connect_disconnect_client/Cargo.toml @@ -6,10 +6,41 @@ license-file = "LICENSE" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +crate-type = ["cdylib", "rlib"] + +[features] +default = ["native"] + +# Builds the existing CLI test client. +native = ["dep:tokio"] + +# Builds the client for wasm32-unknown-unknown using the Rust SDK `web` backend. +web = [ + "spacetimedb-sdk/web", + "dep:wasm-bindgen", + "dep:wasm-bindgen-futures", + "dep:console_error_panic_hook", + "dep:gloo-timers", + "dep:futures", +] + +[[bin]] +name = "connect_disconnect_client" +path = "src/main.rs" +required-features = ["native"] + [dependencies] spacetimedb-sdk = { path = "../.." } test-counter = { path = "../test-counter" } anyhow.workspace = true +futures = { workspace = true, optional = true } +tokio = { workspace = true, optional = true } +gloo-timers = { version = "0.3.0", features = ["futures"], optional = true } + +wasm-bindgen = { version = "0.2.100", optional = true } +wasm-bindgen-futures = { version = "0.4.45", optional = true } +console_error_panic_hook = { version = "0.1.7", optional = true } [lints] workspace = true diff --git a/sdks/rust/tests/connect_disconnect_client/src/lib.rs b/sdks/rust/tests/connect_disconnect_client/src/lib.rs new file mode 100644 index 00000000000..d6f5cd68b52 --- /dev/null +++ b/sdks/rust/tests/connect_disconnect_client/src/lib.rs @@ -0,0 +1,17 @@ +#![allow(clippy::disallowed_macros)] + +#[path = "main.rs"] +mod cli; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +use wasm_bindgen::prelude::wasm_bindgen; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +#[wasm_bindgen] +pub async fn run(_test_name: String, db_name: String) { + console_error_panic_hook::set_once(); + // The shared wasm test harness always passes `(test_name, db_name)`, even for + // fixed-flow clients like this one that ignore the selector. + cli::set_web_db_name(db_name); + cli::dispatch().await; +} diff --git a/sdks/rust/tests/connect_disconnect_client/src/main.rs b/sdks/rust/tests/connect_disconnect_client/src/main.rs index 9b17f1bf9c1..7c42d0827f6 100644 --- a/sdks/rust/tests/connect_disconnect_client/src/main.rs +++ b/sdks/rust/tests/connect_disconnect_client/src/main.rs @@ -1,18 +1,46 @@ -mod module_bindings; +pub(crate) mod module_bindings; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +use std::sync::OnceLock; use module_bindings::*; -use spacetimedb_sdk::{DbContext, Table}; +use spacetimedb_sdk::{DbConnectionBuilder, DbContext, Table}; use test_counter::TestCounter; const LOCALHOST: &str = "http://localhost:3000"; +#[cfg(all(target_arch = "wasm32", feature = "web"))] +static WEB_DB_NAME: OnceLock = OnceLock::new(); + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +pub(crate) fn set_web_db_name(db_name: String) { + WEB_DB_NAME.set(db_name).expect("WASM DB name was already initialized"); +} + fn db_name_or_panic() -> String { - std::env::var("SPACETIME_SDK_TEST_DB_NAME").expect("Failed to read db name from env") + #[cfg(all(target_arch = "wasm32", feature = "web"))] + { + return WEB_DB_NAME + .get() + .cloned() + .expect("Failed to read db name from wasm runner"); + } + + #[cfg(not(all(target_arch = "wasm32", feature = "web")))] + { + std::env::var("SPACETIME_SDK_TEST_DB_NAME").expect("Failed to read db name from env") + } } +#[cfg(not(target_arch = "wasm32"))] fn main() { + // Keep a single async execution path so native and wasm exercise the same logic. + tokio::runtime::Runtime::new().unwrap().block_on(dispatch()); +} + +pub(crate) async fn dispatch() { let disconnect_test_counter = TestCounter::new(); let disconnect_result = disconnect_test_counter.add_test("disconnect"); @@ -56,18 +84,21 @@ fn main() { Some(err) => disconnect_result(Err(anyhow::anyhow!("{err:?}"))), None => disconnect_result(Ok(())), } - }) - .build() - .unwrap(); + }); + let connection = build_connection(connection).await; + #[cfg(not(target_arch = "wasm32"))] let join_handle = connection.run_threaded(); + #[cfg(target_arch = "wasm32")] + connection.run_background_task(); - connect_test_counter.wait_for_all(); + wait_for_all(&connect_test_counter).await; - connection.disconnect().unwrap(); + disconnect_connection(&connection).await; + #[cfg(not(target_arch = "wasm32"))] join_handle.join().unwrap(); - disconnect_test_counter.wait_for_all(); + wait_for_all(&disconnect_test_counter).await; let reconnect_test_counter = TestCounter::new(); let reconnected_result = reconnect_test_counter.add_test("on_reconnect"); @@ -79,9 +110,8 @@ fn main() { reconnected_result(Ok(())); }) .with_database_name(db_name_or_panic()) - .with_uri(LOCALHOST) - .build() - .unwrap(); + .with_uri(LOCALHOST); + let new_connection = build_connection(new_connection).await; new_connection .subscription_builder() @@ -103,7 +133,51 @@ fn main() { .on_error(|_ctx, error| panic!("subscription on_error: {error:?}")) .subscribe("SELECT * FROM disconnected"); - new_connection.run_threaded(); + #[cfg(not(target_arch = "wasm32"))] + let reconnect_join_handle = new_connection.run_threaded(); + #[cfg(target_arch = "wasm32")] + new_connection.run_background_task(); + + wait_for_all(&reconnect_test_counter).await; + + // The second connection has no disconnect assertion, but it still needs an explicit + // shutdown on wasm so the background task releases its websocket before the test exits. + disconnect_connection(&new_connection).await; + #[cfg(not(target_arch = "wasm32"))] + reconnect_join_handle.join().unwrap(); +} + +async fn wait_for_all(test_counter: &std::sync::Arc) { + #[cfg(target_arch = "wasm32")] + { + // wasm/web callbacks run on the JS event loop, so this wait must stay async. + test_counter.wait_for_all_async().await; + return; + } + + #[cfg(not(target_arch = "wasm32"))] + test_counter.wait_for_all(); +} + +#[cfg(not(target_arch = "wasm32"))] +async fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + builder.build().unwrap() +} + +#[cfg(target_arch = "wasm32")] +async fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + // Web builds use async connection setup, so awaiting here avoids blocking the event loop + // before websocket callbacks have a chance to run. + builder.build().await.unwrap() +} + +async fn disconnect_connection(connection: &DbConnection) { + connection.disconnect().unwrap(); - reconnect_test_counter.wait_for_all(); + #[cfg(target_arch = "wasm32")] + { + // Yield once so the queued disconnect mutation is processed by the background task + // before the wasm test function returns to Node. + gloo_timers::future::TimeoutFuture::new(0).await; + } } diff --git a/sdks/rust/tests/event-table-client/Cargo.toml b/sdks/rust/tests/event-table-client/Cargo.toml index f6502644119..b5b15b11889 100644 --- a/sdks/rust/tests/event-table-client/Cargo.toml +++ b/sdks/rust/tests/event-table-client/Cargo.toml @@ -3,11 +3,39 @@ name = "event-table-client" version.workspace = true edition.workspace = true +[lib] +crate-type = ["cdylib", "rlib"] + +[features] +default = ["native"] + +# Builds the existing CLI test client. +native = [ + "dep:env_logger", +] + +# Builds the client for wasm32-unknown-unknown using the Rust SDK `web` backend. +web = [ + "spacetimedb-sdk/web", + "dep:wasm-bindgen", + "dep:wasm-bindgen-futures", + "dep:futures", +] + +[[bin]] +name = "event-table-client" +path = "src/main.rs" +required-features = ["native"] + [dependencies] spacetimedb-sdk = { path = "../.." } test-counter = { path = "../test-counter" } anyhow.workspace = true -env_logger.workspace = true +env_logger = { workspace = true, optional = true } +futures = { workspace = true, optional = true } + +wasm-bindgen = { version = "0.2.100", optional = true } +wasm-bindgen-futures = { version = "0.4.45", optional = true } [lints] workspace = true diff --git a/sdks/rust/tests/event-table-client/src/lib.rs b/sdks/rust/tests/event-table-client/src/lib.rs new file mode 100644 index 00000000000..9dd50d80e2b --- /dev/null +++ b/sdks/rust/tests/event-table-client/src/lib.rs @@ -0,0 +1,13 @@ +#![allow(clippy::disallowed_macros)] + +#[path = "main.rs"] +mod cli; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +use wasm_bindgen::prelude::wasm_bindgen; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +#[wasm_bindgen] +pub async fn run(test_name: String) { + cli::dispatch(&test_name); +} diff --git a/sdks/rust/tests/event-table-client/src/main.rs b/sdks/rust/tests/event-table-client/src/main.rs index 5edb1fdfe0a..700bb486ed9 100644 --- a/sdks/rust/tests/event-table-client/src/main.rs +++ b/sdks/rust/tests/event-table-client/src/main.rs @@ -1,10 +1,10 @@ #[allow(clippy::too_many_arguments)] #[allow(clippy::large_enum_variant)] -mod module_bindings; +pub(crate) mod module_bindings; use module_bindings::*; -use spacetimedb_sdk::{DbContext, Event, EventTable}; +use spacetimedb_sdk::{DbConnectionBuilder, DbContext, Event, EventTable}; use std::sync::atomic::{AtomicU32, Ordering}; use test_counter::TestCounter; @@ -17,6 +17,7 @@ fn db_name_or_panic() -> String { /// Register a panic hook which will exit the process whenever any thread panics. /// /// This allows us to fail tests by panicking in callbacks. +#[cfg(not(target_arch = "wasm32"))] fn exit_on_panic() { let default_hook = std::panic::take_hook(); std::panic::set_hook(Box::new(move |panic_info| { @@ -41,6 +42,7 @@ macro_rules! assert_eq_or_bail { }}; } +#[cfg(not(target_arch = "wasm32"))] fn main() { env_logger::init(); exit_on_panic(); @@ -49,6 +51,10 @@ fn main() { .nth(1) .expect("Pass a test name as a command-line argument to the test client"); + dispatch(&test); +} + +pub(crate) fn dispatch(test: &str) { match &*test { "event-table" => exec_event_table(), "multiple-events" => exec_multiple_events(), @@ -58,6 +64,16 @@ fn main() { } } +#[cfg(not(target_arch = "wasm32"))] +fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + builder.build().unwrap() +} + +#[cfg(target_arch = "wasm32")] +fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + futures::executor::block_on(builder.build()).unwrap() +} + fn connect_then( test_counter: &std::sync::Arc, callback: impl FnOnce(&DbConnection) + Send + 'static, @@ -71,10 +87,12 @@ fn connect_then( callback(ctx); connected_result(Ok(())); }) - .on_connect_error(|_ctx, error| panic!("Connect errored: {error:?}")) - .build() - .unwrap(); + .on_connect_error(|_ctx, error| panic!("Connect errored: {error:?}")); + let conn = build_connection(conn); + #[cfg(not(target_arch = "wasm32"))] conn.run_threaded(); + #[cfg(target_arch = "wasm32")] + conn.run_background_task(); conn } diff --git a/sdks/rust/tests/procedure-client/Cargo.toml b/sdks/rust/tests/procedure-client/Cargo.toml index 665d4b0582d..7678a691eee 100644 --- a/sdks/rust/tests/procedure-client/Cargo.toml +++ b/sdks/rust/tests/procedure-client/Cargo.toml @@ -6,13 +6,41 @@ license-file = "LICENSE" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +crate-type = ["cdylib", "rlib"] + +[features] +default = ["native"] + +# Builds the existing CLI test client. +native = [ + "dep:env_logger", +] + +# Builds the client for wasm32-unknown-unknown using the Rust SDK `web` backend. +web = [ + "spacetimedb-sdk/web", + "dep:wasm-bindgen", + "dep:wasm-bindgen-futures", + "dep:futures", +] + +[[bin]] +name = "procedure-client" +path = "src/main.rs" +required-features = ["native"] + [dependencies] spacetimedb-sdk = { path = "../.." } spacetimedb-lib.workspace = true test-counter = { path = "../test-counter" } anyhow.workspace = true -env_logger.workspace = true +env_logger = { workspace = true, optional = true } serde_json.workspace = true +futures = { workspace = true, optional = true } + +wasm-bindgen = { version = "0.2.100", optional = true } +wasm-bindgen-futures = { version = "0.4.45", optional = true } [lints] workspace = true diff --git a/sdks/rust/tests/procedure-client/src/lib.rs b/sdks/rust/tests/procedure-client/src/lib.rs new file mode 100644 index 00000000000..9dd50d80e2b --- /dev/null +++ b/sdks/rust/tests/procedure-client/src/lib.rs @@ -0,0 +1,13 @@ +#![allow(clippy::disallowed_macros)] + +#[path = "main.rs"] +mod cli; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +use wasm_bindgen::prelude::wasm_bindgen; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +#[wasm_bindgen] +pub async fn run(test_name: String) { + cli::dispatch(&test_name); +} diff --git a/sdks/rust/tests/procedure-client/src/main.rs b/sdks/rust/tests/procedure-client/src/main.rs index cdafbff4dc9..b99b6eef0b4 100644 --- a/sdks/rust/tests/procedure-client/src/main.rs +++ b/sdks/rust/tests/procedure-client/src/main.rs @@ -1,4 +1,4 @@ -mod module_bindings; +pub(crate) mod module_bindings; use core::time::Duration; @@ -13,6 +13,7 @@ const LOCALHOST: &str = "http://localhost:3000"; /// Register a panic hook which will exit the process whenever any thread panics. /// /// This allows us to fail tests by panicking in callbacks. +#[cfg(not(target_arch = "wasm32"))] fn exit_on_panic() { // The default panic hook is responsible for printing the panic message and backtrace to stderr. // Grab a handle on it, and invoke it in our custom hook before exiting. @@ -30,6 +31,7 @@ fn db_name_or_panic() -> String { std::env::var("SPACETIME_SDK_TEST_DB_NAME").expect("Failed to read db name from env") } +#[cfg(not(target_arch = "wasm32"))] fn main() { env_logger::init(); exit_on_panic(); @@ -38,6 +40,10 @@ fn main() { .nth(1) .expect("Pass a test name as a command-line argument to the test client"); + dispatch(&test); +} + +pub(crate) fn dispatch(test: &str) { match &*test { "procedure-return-values" => exec_procedure_return_values(), "procedure-observe-panic" => exec_procedure_panic(), @@ -51,6 +57,16 @@ fn main() { } } +#[cfg(not(target_arch = "wasm32"))] +fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + builder.build().unwrap() +} + +#[cfg(target_arch = "wasm32")] +fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + futures::executor::block_on(builder.build()).unwrap() +} + fn assert_table_empty(tbl: T) -> anyhow::Result<()> { let count = tbl.count(); if count != 0 { @@ -88,8 +104,11 @@ fn connect_with_then( connected_result(Ok(())); }) .on_connect_error(|_ctx, error| panic!("Connect errored: {error:?}")); - let conn = with_builder(builder).build().unwrap(); + let conn = build_connection(with_builder(builder)); + #[cfg(not(target_arch = "wasm32"))] conn.run_threaded(); + #[cfg(target_arch = "wasm32")] + conn.run_background_task(); conn } diff --git a/sdks/rust/tests/test-client/Cargo.toml b/sdks/rust/tests/test-client/Cargo.toml index 7a7167234a5..b4a3337d8c9 100644 --- a/sdks/rust/tests/test-client/Cargo.toml +++ b/sdks/rust/tests/test-client/Cargo.toml @@ -6,13 +6,49 @@ license-file = "LICENSE" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +crate-type = ["cdylib", "rlib"] + +[features] +default = ["native"] + +# Builds the existing CLI test client. +native = [ + "dep:tokio", + "dep:env_logger", + "dep:rand", +] + +# Builds the client for wasm32-unknown-unknown using the Rust SDK `web` backend. +web = [ + "spacetimedb-sdk/web", + "dep:wasm-bindgen", + "dep:wasm-bindgen-futures", + "dep:console_error_panic_hook", + "dep:gloo-timers", + "dep:futures", + "dep:rand", +] + +[[bin]] +name = "test-client" +path = "src/main.rs" +required-features = ["native"] + [dependencies] spacetimedb-sdk = { path = "../.." } -test-counter = { path = "../test-counter" } -tokio.workspace = true anyhow.workspace = true -env_logger.workspace = true -rand.workspace = true + +test-counter = { path = "../test-counter" } +tokio = { workspace = true, optional = true } +env_logger = { workspace = true, optional = true } +rand = { workspace = true, optional = true } +futures = { workspace = true, optional = true } +gloo-timers = { version = "0.3.0", features = ["futures"], optional = true } + +wasm-bindgen = { version = "0.2.100", optional = true } +wasm-bindgen-futures = { version = "0.4.45", optional = true } +console_error_panic_hook = { version = "0.1.7", optional = true } [lints] workspace = true diff --git a/sdks/rust/tests/test-client/src/lib.rs b/sdks/rust/tests/test-client/src/lib.rs new file mode 100644 index 00000000000..049a22aed54 --- /dev/null +++ b/sdks/rust/tests/test-client/src/lib.rs @@ -0,0 +1,17 @@ +#![allow(clippy::disallowed_macros)] + +#[path = "main.rs"] +mod cli; + +pub(crate) use cli::module_bindings; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +use wasm_bindgen::prelude::wasm_bindgen; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +#[wasm_bindgen] +pub async fn run(test_name: String, db_name: String) { + console_error_panic_hook::set_once(); + cli::set_web_db_name(db_name); + cli::dispatch(&test_name).await; +} diff --git a/sdks/rust/tests/test-client/src/main.rs b/sdks/rust/tests/test-client/src/main.rs index 89dac3b00c7..ec7c981fd8e 100644 --- a/sdks/rust/tests/test-client/src/main.rs +++ b/sdks/rust/tests/test-client/src/main.rs @@ -1,19 +1,23 @@ #[allow(clippy::too_many_arguments)] #[allow(clippy::large_enum_variant)] -mod module_bindings; +pub(crate) mod module_bindings; use core::fmt::Display; -use core::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Barrier, Mutex}; +use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +#[cfg(target_arch = "wasm32")] +use std::cell::RefCell; +use std::sync::{Arc, Mutex}; use module_bindings::*; use rand::RngCore; +#[cfg(not(target_arch = "wasm32"))] +use spacetimedb_sdk::credentials; use spacetimedb_sdk::error::InternalError; use spacetimedb_sdk::TableWithPrimaryKey; use spacetimedb_sdk::{ - credentials, i256, u256, Compression, ConnectionId, DbConnectionBuilder, DbContext, Event, Identity, ReducerEvent, - Status, SubscriptionHandle, Table, TimeDuration, Timestamp, Uuid, + i256, u256, Compression, ConnectionId, DbConnectionBuilder, DbContext, Event, Identity, ReducerEvent, Status, + SubscriptionHandle, Table, TimeDuration, Timestamp, Uuid, }; use test_counter::TestCounter; @@ -28,13 +32,180 @@ use unique_test_table::{insert_then_delete_one, UniqueTestTable}; const LOCALHOST: &str = "http://localhost:3000"; +fn fixed_test_timestamp() -> Timestamp { + // `Timestamp::now()` is stubbed on `wasm32-unknown-unknown`, so client-side tests + // that need a timestamp value must use a deterministic literal instead of wall-clock time. + Timestamp::from_micros_since_unix_epoch(1_706_000_000_000_000) +} + +// ---- Test harness platform shim layer ---- +// +// Keep all intentional native-vs-wasm differences concentrated here so reviewers can reason +// about the rest of the file as ordinary shared test logic. The desired high-level behavior is: +// - all tests call the same async `connect*` and `wait_for_all` helpers +// - native keeps using its existing threaded SDK runtime +// - wasm keeps the JS event loop unblocked and preserves native-style connection lifetimes + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +static WEB_DB_NAME: std::sync::OnceLock = std::sync::OnceLock::new(); + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +pub(crate) fn set_web_db_name(db_name: String) { + WEB_DB_NAME.set(db_name).expect("WASM DB name was already initialized"); +} + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +thread_local! { + // Why wasm needs extra lifetime plumbing: + // - Native test clients call `run_threaded()`, so SDK websocket work keeps progressing on a + // dedicated background thread even if the test thread blocks in `wait_for_all()`. + // - wasm/web runs the client and its websocket callbacks on the same single-threaded JS event + // loop as the test body. There is no equivalent background OS thread here. + // - That means an early last-handle drop disconnects the websocket immediately, and a blocking + // wait would freeze the very event loop that is supposed to deliver the remaining callbacks. + // Retain connections until the async wait completes so wasm observes the same logical lifetime + // that the native tests were implicitly relying on. + // Many native-first tests intentionally ignore the returned connection and rely on the + // background message loop to stay alive until `wait_for_all()` finishes. Retain a clone on + // wasm so the final-drop disconnect cleanup does not self-terminate those tests early. + static RETAINED_WASM_CONNECTIONS: RefCell> = const { RefCell::new(Vec::new()) }; +} + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +fn retain_connection_until_wait(connection: &ManagedConnection) { + RETAINED_WASM_CONNECTIONS.with(|connections| connections.borrow_mut().push(connection.clone())); +} + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +fn release_retained_connections() { + RETAINED_WASM_CONNECTIONS.with(|connections| connections.borrow_mut().clear()); +} + +#[cfg(not(all(target_arch = "wasm32", feature = "web")))] +fn retain_connection_until_wait(_: &ManagedConnection) {} + +#[cfg(not(all(target_arch = "wasm32", feature = "web")))] +fn release_retained_connections() {} + +// `ManagedConnection` exists to separate test logic from connection-lifetime mechanics. +// The generated `DbConnection` is a single owned handle, but this harness now needs two extra +// properties: +// - shared ownership, so wasm can retain a connection until `wait_for_all()` even when the test +// itself ignores the returned value, and +// - final-owner disconnect semantics, so only the last live handle shuts down the websocket. +// Wrapping the connection in an `Arc` and re-exposing `DbContext` here keeps those lifecycle rules +// local to the shim layer instead of forcing the individual tests to reason about them. +#[derive(Clone)] +struct ManagedConnection(Arc); + +impl ManagedConnection { + fn new(connection: DbConnection) -> Self { + Self(Arc::new(connection)) + } +} + +impl core::ops::Deref for ManagedConnection { + type Target = DbConnection; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl DbContext for ManagedConnection { + type DbView = ::DbView; + type Reducers = ::Reducers; + type Procedures = ::Procedures; + type SubscriptionBuilder = ::SubscriptionBuilder; + + fn db(&self) -> &Self::DbView { + &self.0.db + } + + fn reducers(&self) -> &Self::Reducers { + &self.0.reducers + } + + fn procedures(&self) -> &Self::Procedures { + &self.0.procedures + } + + fn is_active(&self) -> bool { + self.0.is_active() + } + + fn disconnect(&self) -> spacetimedb_sdk::Result<()> { + self.0.disconnect() + } + + fn subscription_builder(&self) -> Self::SubscriptionBuilder { + self.0.subscription_builder() + } + + fn try_identity(&self) -> Option { + self.0.try_identity() + } + + fn connection_id(&self) -> ConnectionId { + self.0.connection_id() + } + + fn try_connection_id(&self) -> Option { + self.0.try_connection_id() + } +} + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +impl Drop for ManagedConnection { + fn drop(&mut self) { + if Arc::strong_count(&self.0) == 1 { + // Only the final owner should disconnect. wasm tests now clone connections into a + // temporary retention list so ignored return values stay alive until `wait_for_all()`. + // Disconnecting from any earlier clone drop would let the websocket shut down while + // the test still expects subscription or reducer callbacks to arrive. + let _ = self.0.disconnect(); + } + } +} + +/// Read the per-test database name from the native process environment or the wasm runner shim. +/// Keeping this lookup centralized avoids sprinkling target-specific configuration through tests. fn db_name_or_panic() -> String { - std::env::var("SPACETIME_SDK_TEST_DB_NAME").expect("Failed to read db name from env") + #[cfg(all(target_arch = "wasm32", feature = "web"))] + { + return WEB_DB_NAME + .get() + .cloned() + .expect("Failed to read db name from wasm runner"); + } + + #[cfg(not(all(target_arch = "wasm32", feature = "web")))] + { + std::env::var("SPACETIME_SDK_TEST_DB_NAME").expect("Failed to read db name from env") + } +} + +#[cfg(not(target_arch = "wasm32"))] +fn start_connection_runtime(connection: &ManagedConnection) { + // Native keeps websocket/callback work on a dedicated SDK thread, so the test body can block + // later without preventing callbacks from being delivered. + connection.run_threaded(); + retain_connection_until_wait(connection); +} + +#[cfg(target_arch = "wasm32")] +fn start_connection_runtime(connection: &ManagedConnection) { + // wasm/web has no background OS thread in this harness. Start the async message loop and keep + // the connection retained until `wait_for_all()` so native-style ignored return values do not + // disconnect the websocket before the expected callbacks have run. + connection.run_background_task(); + retain_connection_until_wait(connection); } /// Register a panic hook which will exit the process whenever any thread panics. /// /// This allows us to fail tests by panicking in callbacks. +#[cfg(not(target_arch = "wasm32"))] fn exit_on_panic() { // The default panic hook is responsible for printing the panic message and backtrace to stderr. // Grab a handle on it, and invoke it in our custom hook before exiting. @@ -64,6 +235,7 @@ macro_rules! assert_eq_or_bail { }}; } +#[cfg(not(target_arch = "wasm32"))] fn main() { env_logger::init(); exit_on_panic(); @@ -71,77 +243,82 @@ fn main() { let test = std::env::args() .nth(1) .expect("Pass a test name as a command-line argument to the test client"); + // Keep a single async dispatch entrypoint so wasm and native execute + // the same test bodies; native blocks at the CLI boundary only. + tokio::runtime::Runtime::new().unwrap().block_on(dispatch(&test)); +} - match &*test { - "insert-primitive" => exec_insert_primitive(), - "subscribe-and-cancel" => exec_subscribe_and_cancel(), - "subscribe-and-unsubscribe" => exec_subscribe_and_unsubscribe(), - "subscription-error-smoke-test" => exec_subscription_error_smoke_test(), - "delete-primitive" => exec_delete_primitive(), - "update-primitive" => exec_update_primitive(), +pub(crate) async fn dispatch(test: &str) { + match test { + "insert-primitive" => exec_insert_primitive().await, + "subscribe-and-cancel" => exec_subscribe_and_cancel().await, + "subscribe-and-unsubscribe" => exec_subscribe_and_unsubscribe().await, + "subscription-error-smoke-test" => exec_subscription_error_smoke_test().await, + "delete-primitive" => exec_delete_primitive().await, + "update-primitive" => exec_update_primitive().await, - "insert-identity" => exec_insert_identity(), - "insert-caller-identity" => exec_insert_caller_identity(), - "delete-identity" => exec_delete_identity(), - "update-identity" => exec_update_identity(), + "insert-identity" => exec_insert_identity().await, + "insert-caller-identity" => exec_insert_caller_identity().await, + "delete-identity" => exec_delete_identity().await, + "update-identity" => exec_update_identity().await, - "insert-connection-id" => exec_insert_connection_id(), - "insert-caller-connection-id" => exec_insert_caller_connection_id(), - "delete-connection-id" => exec_delete_connection_id(), - "update-connection-id" => exec_update_connection_id(), + "insert-connection-id" => exec_insert_connection_id().await, + "insert-caller-connection-id" => exec_insert_caller_connection_id().await, + "delete-connection-id" => exec_delete_connection_id().await, + "update-connection-id" => exec_update_connection_id().await, - "insert-timestamp" => exec_insert_timestamp(), - "insert-call-timestamp" => exec_insert_call_timestamp(), + "insert-timestamp" => exec_insert_timestamp().await, + "insert-call-timestamp" => exec_insert_call_timestamp().await, - "insert-uuid" => exec_insert_uuid(), - "insert-call-uuid-v4" => exec_insert_caller_uuid_v4(), - "insert-call-uuid-v7" => exec_insert_caller_uuid_v7(), - "delete-uuid" => exec_delete_uuid(), - "update-uuid" => exec_update_uuid(), + "insert-uuid" => exec_insert_uuid().await, + "insert-call-uuid-v4" => exec_insert_caller_uuid_v4().await, + "insert-call-uuid-v7" => exec_insert_caller_uuid_v7().await, + "delete-uuid" => exec_delete_uuid().await, + "update-uuid" => exec_update_uuid().await, - "on-reducer" => exec_on_reducer(), - "fail-reducer" => exec_fail_reducer(), + "on-reducer" => exec_on_reducer().await, + "fail-reducer" => exec_fail_reducer().await, - "insert-vec" => exec_insert_vec(), - "insert-option-some" => exec_insert_option_some(), - "insert-option-none" => exec_insert_option_none(), - "insert-struct" => exec_insert_struct(), - "insert-simple-enum" => exec_insert_simple_enum(), - "insert-enum-with-payload" => exec_insert_enum_with_payload(), + "insert-vec" => exec_insert_vec().await, + "insert-option-some" => exec_insert_option_some().await, + "insert-option-none" => exec_insert_option_none().await, + "insert-struct" => exec_insert_struct().await, + "insert-simple-enum" => exec_insert_simple_enum().await, + "insert-enum-with-payload" => exec_insert_enum_with_payload().await, - "insert-delete-large-table" => exec_insert_delete_large_table(), + "insert-delete-large-table" => exec_insert_delete_large_table().await, - "insert-primitives-as-strings" => exec_insert_primitives_as_strings(), + "insert-primitives-as-strings" => exec_insert_primitives_as_strings().await, // "resubscribe" => exec_resubscribe(), // - "reauth-part-1" => exec_reauth_part_1(), - "reauth-part-2" => exec_reauth_part_2(), + "reauth-part-1" => exec_reauth_part_1().await, + "reauth-part-2" => exec_reauth_part_2().await, - "should-fail" => exec_should_fail(), + "should-fail" => exec_should_fail().await, - "reconnect-different-connection-id" => exec_reconnect_different_connection_id(), - "caller-always-notified" => exec_caller_always_notified(), + "reconnect-different-connection-id" => exec_reconnect_different_connection_id().await, + "caller-always-notified" => exec_caller_always_notified().await, - "subscribe-all-select-star" => exec_subscribe_all_select_star(), + "subscribe-all-select-star" => exec_subscribe_all_select_star().await, "caller-alice-receives-reducer-callback-but-not-bob" => { - exec_caller_alice_receives_reducer_callback_but_not_bob() + exec_caller_alice_receives_reducer_callback_but_not_bob().await } - "row-deduplication" => exec_row_deduplication(), - "row-deduplication-join-r-and-s" => exec_row_deduplication_join_r_and_s(), - "row-deduplication-r-join-s-and-r-joint" => exec_row_deduplication_r_join_s_and_r_join_t(), - "test-lhs-join-update" => test_lhs_join_update(), - "test-lhs-join-update-disjoint-queries" => test_lhs_join_update_disjoint_queries(), - "test-intra-query-bag-semantics-for-join" => test_intra_query_bag_semantics_for_join(), - "two-different-compression-algos" => exec_two_different_compression_algos(), - "test-parameterized-subscription" => test_parameterized_subscription(), - "test-rls-subscription" => test_rls_subscription(), - "pk-simple-enum" => exec_pk_simple_enum(), - "indexed-simple-enum" => exec_indexed_simple_enum(), - - "overlapping-subscriptions" => exec_overlapping_subscriptions(), - - "sorted-uuids-insert" => exec_sorted_uuids_insert(), + "row-deduplication" => exec_row_deduplication().await, + "row-deduplication-join-r-and-s" => exec_row_deduplication_join_r_and_s().await, + "row-deduplication-r-join-s-and-r-joint" => exec_row_deduplication_r_join_s_and_r_join_t().await, + "test-lhs-join-update" => test_lhs_join_update().await, + "test-lhs-join-update-disjoint-queries" => test_lhs_join_update_disjoint_queries().await, + "test-intra-query-bag-semantics-for-join" => test_intra_query_bag_semantics_for_join().await, + "two-different-compression-algos" => exec_two_different_compression_algos().await, + "test-parameterized-subscription" => test_parameterized_subscription().await, + "test-rls-subscription" => test_rls_subscription().await, + "pk-simple-enum" => exec_pk_simple_enum().await, + "indexed-simple-enum" => exec_indexed_simple_enum().await, + + "overlapping-subscriptions" => exec_overlapping_subscriptions().await, + + "sorted-uuids-insert" => exec_sorted_uuids_insert().await, _ => panic!("Unknown test: {test}"), } @@ -387,12 +564,12 @@ const SUBSCRIBE_ALL: &[&str] = &[ "SELECT * FROM table_holds_table;", ]; -fn connect_with_then( +async fn connect_with_then( test_counter: &std::sync::Arc, on_connect_suffix: &str, with_builder: impl FnOnce(DbConnectionBuilder) -> DbConnectionBuilder, callback: impl FnOnce(&DbConnection) + Send + 'static, -) -> DbConnection { +) -> ManagedConnection { let connected_result = test_counter.add_test(format!("on_connect_{on_connect_suffix}")); let name = db_name_or_panic(); let builder = DbConnection::builder() @@ -403,20 +580,46 @@ fn connect_with_then( connected_result(Ok(())); }) .on_connect_error(|_ctx, error| panic!("Connect errored: {error:?}")); - let conn = with_builder(builder).build().unwrap(); - conn.run_threaded(); + let conn = build_connection(with_builder(builder)).await; + start_connection_runtime(&conn); conn } -fn connect_then( +async fn connect_then( test_counter: &std::sync::Arc, callback: impl FnOnce(&DbConnection) + Send + 'static, -) -> DbConnection { - connect_with_then(test_counter, "", |x| x, callback) +) -> ManagedConnection { + connect_with_then(test_counter, "", |x| x, callback).await } -fn connect(test_counter: &std::sync::Arc) -> DbConnection { - connect_then(test_counter, |_| {}) +async fn connect(test_counter: &std::sync::Arc) -> ManagedConnection { + connect_then(test_counter, |_| {}).await +} + +async fn wait_for_all(test_counter: &std::sync::Arc) { + // Use one shared async wait entrypoint even on native. The TestCounter implementation hides + // the native blocking wait vs wasm event-loop-friendly poll, which keeps test bodies uniform. + test_counter.wait_for_all_async().await; + // This is a no-op on native and the wasm-side lifetime release point on web. + release_retained_connections(); +} + +#[cfg(not(target_arch = "wasm32"))] +async fn build_connection(builder: DbConnectionBuilder) -> ManagedConnection { + // Keep the helper async even though native `build()` is synchronous so shared callers do not + // need target-specific branches. + ManagedConnection::new(builder.build().unwrap()) +} + +#[cfg(target_arch = "wasm32")] +async fn build_connection(builder: DbConnectionBuilder) -> ManagedConnection { + // Why this differs from native: + // - In the SDK, `DbConnectionBuilder::build` is sync on non-web builds, + // but async on `feature = "web"` because the websocket/token setup uses + // wasm/web async primitives. + // - We therefore keep the helper async and await directly so wasm stays + // non-blocking and can make forward progress on the JS event loop. + ManagedConnection::new(builder.build().await.unwrap()) } fn subscribe_all_then(ctx: &impl RemoteDbContext, callback: impl FnOnce(&SubscriptionEventContext) + Send + 'static) { @@ -448,7 +651,7 @@ fn reducer_callback_assert_committed( move |_ctx, outcome| assert_outcome_committed(reducer_name, outcome) } -fn exec_subscribe_and_cancel() { +async fn exec_subscribe_and_cancel() { let test_counter = TestCounter::new(); let cb = test_counter.add_test("unsubscribe_then_called"); connect_then(&test_counter, { @@ -471,11 +674,12 @@ fn exec_subscribe_and_cancel() { })) .unwrap(); } - }); - test_counter.wait_for_all(); + }) + .await; + wait_for_all(&test_counter).await; } -fn exec_subscribe_and_unsubscribe() { +async fn exec_subscribe_and_unsubscribe() { let test_counter = TestCounter::new(); let cb = test_counter.add_test("unsubscribe_then_called"); connect_then(&test_counter, { @@ -509,11 +713,12 @@ fn exec_subscribe_and_unsubscribe() { assert!(!handle.is_active()); assert!(!handle.is_ended()); } - }); - test_counter.wait_for_all(); + }) + .await; + wait_for_all(&test_counter).await; } -fn exec_subscription_error_smoke_test() { +async fn exec_subscription_error_smoke_test() { let test_counter = TestCounter::new(); let cb = test_counter.add_test("error_callback_is_called"); connect_then(&test_counter, { @@ -528,15 +733,16 @@ fn exec_subscription_error_smoke_test() { assert!(!handle.is_active()); assert!(!handle.is_ended()); } - }); - test_counter.wait_for_all(); + }) + .await; + wait_for_all(&test_counter).await; } /// This tests that we can: /// - Pass primitive types to reducers. /// - Deserialize primitive types in rows and in reducer arguments. /// - Observe `on_insert` callbacks with appropriate reducer events. -fn exec_insert_primitive() { +async fn exec_insert_primitive() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -569,13 +775,14 @@ fn exec_insert_primitive() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can observe `on_delete` callbacks. -fn exec_delete_primitive() { +async fn exec_delete_primitive() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -604,15 +811,16 @@ fn exec_delete_primitive() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; assert_all_tables_empty(&connection).unwrap(); } /// This tests that we can distinguish between `on_update` and `on_delete` callbacks for tables with primary keys. -fn exec_update_primitive() { +async fn exec_update_primitive() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -641,15 +849,16 @@ fn exec_update_primitive() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; assert_all_tables_empty(&connection).unwrap(); } /// This tests that we can serialize and deserialize `Identity` in various contexts. -fn exec_insert_identity() { +async fn exec_insert_identity() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -666,13 +875,14 @@ fn exec_insert_identity() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }) } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can retrieve and use the caller's `Identity` from the reducer context. -fn exec_insert_caller_identity() { +async fn exec_insert_caller_identity() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -690,14 +900,15 @@ fn exec_insert_caller_identity() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This test doesn't add much alongside `exec_insert_identity` and `exec_delete_primitive`, /// but it's here for symmetry. -fn exec_delete_identity() { +async fn exec_delete_identity() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -710,16 +921,17 @@ fn exec_delete_identity() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; assert_all_tables_empty(&connection).unwrap(); } /// This tests that we can distinguish between `on_delete` and `on_update` events /// for tables with `Identity` primary keys. -fn exec_update_identity() { +async fn exec_update_identity() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -732,15 +944,16 @@ fn exec_update_identity() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; assert_all_tables_empty(&connection).unwrap(); } /// This tests that we can serialize and deserialize `ConnectionId` in various contexts. -fn exec_insert_connection_id() { +async fn exec_insert_connection_id() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -753,13 +966,14 @@ fn exec_insert_connection_id() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can serialize and deserialize `ConnectionId` in various contexts. -fn exec_insert_caller_connection_id() { +async fn exec_insert_caller_connection_id() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -778,14 +992,15 @@ fn exec_insert_caller_connection_id() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This test doesn't add much alongside `exec_insert_connection_id` and `exec_delete_primitive`, /// but it's here for symmetry. -fn exec_delete_connection_id() { +async fn exec_delete_connection_id() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -798,20 +1013,21 @@ fn exec_delete_connection_id() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; assert_all_tables_empty(&connection).unwrap(); } /// This tests that we can distinguish between `on_delete` and `on_update` events /// for tables with `ConnectionId` primary keys. -fn exec_update_connection_id() { +async fn exec_update_connection_id() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -828,12 +1044,12 @@ fn exec_update_connection_id() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; assert_all_tables_empty(&connection).unwrap(); } -fn exec_insert_timestamp() { +async fn exec_insert_timestamp() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -841,17 +1057,18 @@ fn exec_insert_timestamp() { let test_counter = test_counter.clone(); move |ctx| { subscribe_all_then(ctx, move |ctx| { - insert_one::(ctx, &test_counter, Timestamp::now()); + insert_one::(ctx, &test_counter, fixed_test_timestamp()); sub_applied_nothing_result(assert_all_tables_empty(ctx)); }) } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } -fn exec_insert_call_timestamp() { +async fn exec_insert_call_timestamp() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -890,12 +1107,13 @@ fn exec_insert_call_timestamp() { }); sub_applied_nothing_result(assert_all_tables_empty(ctx)); } - }); - test_counter.wait_for_all(); + }) + .await; + wait_for_all(&test_counter).await; } /// This tests that we can serialize and deserialize `Uuid` in various contexts. -fn exec_insert_uuid() { +async fn exec_insert_uuid() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -908,13 +1126,14 @@ fn exec_insert_uuid() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can serialize and deserialize `Uuid` in various contexts. -fn exec_insert_caller_uuid_v4() { +async fn exec_insert_caller_uuid_v4() { /* let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -931,15 +1150,15 @@ fn exec_insert_caller_uuid_v4() { } }); - test_counter.wait_for_all();*/ + wait_for_all(&test_counter).await;*/ } /// This tests that we can serialize and deserialize `Uuid` in various contexts. -fn exec_insert_caller_uuid_v7() {} +async fn exec_insert_caller_uuid_v7() {} /// This test doesn't add much alongside `exec_insert_uuid` and `exec_delete_primitive`, /// but it's here for symmetry. -fn exec_delete_uuid() { +async fn exec_delete_uuid() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -952,20 +1171,21 @@ fn exec_delete_uuid() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; assert_all_tables_empty(&connection).unwrap(); } /// This tests that we can distinguish between `on_delete` and `on_update` events /// for tables with `Uuid` primary keys. -fn exec_update_uuid() { +async fn exec_update_uuid() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -976,17 +1196,17 @@ fn exec_update_uuid() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; assert_all_tables_empty(&connection).unwrap(); } /// This tests that we can observe reducer callbacks for successful reducer runs. -fn exec_on_reducer() { +async fn exec_on_reducer() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; let reducer_result = test_counter.add_test("reducer-callback"); @@ -1027,17 +1247,17 @@ fn exec_on_reducer() { .unwrap(); }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can observe reducer callbacks for failed reducers. -fn exec_fail_reducer() { +async fn exec_fail_reducer() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); let reducer_success_result = test_counter.add_test("reducer-callback-success"); let reducer_fail_result = test_counter.add_test("reducer-callback-failure"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; let key = 128; let initial_data = 0xbeef; @@ -1127,15 +1347,15 @@ fn exec_fail_reducer() { .unwrap(); }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can serialize and deserialize `Vec` in various contexts. -fn exec_insert_vec() { +async fn exec_insert_vec() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -1164,13 +1384,13 @@ fn exec_insert_vec() { insert_one::(ctx, &test_counter, vec![ctx.identity()]); insert_one::(ctx, &test_counter, vec![ctx.connection_id()]); - insert_one::(ctx, &test_counter, vec![Timestamp::now()]); + insert_one::(ctx, &test_counter, vec![fixed_test_timestamp()]); sub_applied_nothing_result(assert_all_tables_empty(ctx)); } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } fn every_primitive_struct() -> EveryPrimitiveStruct { @@ -1265,11 +1485,11 @@ fn large_table() -> LargeTable { /// /// Note that this must be a separate test from [`exec_insert_option_none`], /// as [`insert_one`] cannot handle running multiple tests for the same type in parallel. -fn exec_insert_option_some() { +async fn exec_insert_option_some() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -1285,18 +1505,18 @@ fn exec_insert_option_some() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can serialize and deserialize `Option`s of various payload types which are `None`. /// /// Note that this must be a separate test from [`exec_insert_option_some`], /// as [`insert_one`] cannot handle running multiple tests for the same type in parallel. -fn exec_insert_option_none() { +async fn exec_insert_option_none() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -1312,15 +1532,15 @@ fn exec_insert_option_none() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can serialize and deserialize structs in various contexts. -fn exec_insert_struct() { +async fn exec_insert_struct() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -1339,15 +1559,15 @@ fn exec_insert_struct() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can serialize and deserialize C-style enums in various contexts. -fn exec_insert_simple_enum() { +async fn exec_insert_simple_enum() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -1363,15 +1583,15 @@ fn exec_insert_simple_enum() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that we can serialize and deserialize sum types in various contexts. -fn exec_insert_enum_with_payload() { +async fn exec_insert_enum_with_payload() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -1413,24 +1633,24 @@ fn exec_insert_enum_with_payload() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This tests that the test machinery itself is functional and can detect failures. -fn exec_should_fail() { +async fn exec_should_fail() { let test_counter = TestCounter::new(); let fail = test_counter.add_test("should-fail"); fail(Err(anyhow::anyhow!("This is an intentional failure"))); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// This test invokes a reducer with many arguments of many types, /// and observes a callback for an inserted table with many columns of many types. -fn exec_insert_delete_large_table() { +async fn exec_insert_delete_large_table() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -1536,14 +1756,14 @@ fn exec_insert_delete_large_table() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } -fn exec_insert_primitives_as_strings() { +async fn exec_insert_primitives_as_strings() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -1605,7 +1825,7 @@ fn exec_insert_primitives_as_strings() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } // /// This tests the behavior of re-subscribing @@ -1631,7 +1851,7 @@ fn exec_insert_primitives_as_strings() { // connect_result(connect(LOCALHOST, &name, None)); // // Wait for all previous checks before continuing. -// test_counter.wait_for_all(); +// wait_for_all(&test_counter).await; // // Insert 256 rows of `OneU8`. // // At this point, we should be subscribed to all of them. @@ -1647,7 +1867,7 @@ fn exec_insert_primitives_as_strings() { // insert_one_u_8(n as u8); // } // // Wait for all previous checks before continuing, -// test_counter.wait_for_all(); +// wait_for_all(&test_counter).await; // // and remove the callback now that we're done with it. // OneU8::remove_on_insert(on_insert_u8); // // Re-subscribe with a query that excludes the lower half of the `OneU8` rows, @@ -1679,7 +1899,7 @@ fn exec_insert_primitives_as_strings() { // let subscribe_result = test_counter.add_test("resubscribe"); // subscribe_result(subscribe(&["SELECT * FROM OneU8 WHERE n > 127"])); // // Wait before continuing, and remove callbacks. -// test_counter.wait_for_all(); +// wait_for_all(&test_counter).await; // OneU8::remove_on_delete(on_delete_verify); // OneU8::remove_on_insert(on_insert_panic); @@ -1708,16 +1928,18 @@ fn exec_insert_primitives_as_strings() { // }); // let subscribe_result = test_counter.add_test("resubscribe-again"); // subscribe_result(subscribe(&["SELECT * FROM OneU8"])); -// test_counter.wait_for_all(); +// wait_for_all(&test_counter).await; // } +#[cfg(not(target_arch = "wasm32"))] fn creds_store() -> credentials::File { credentials::File::new("rust-sdk-test") } /// Part of the `reauth` test, this connects to Spacetime to get new credentials, /// and saves them to a file. -fn exec_reauth_part_1() { +#[cfg(not(target_arch = "wasm32"))] +async fn exec_reauth_part_1() { let test_counter = TestCounter::new(); let name = db_name_or_panic(); @@ -1735,14 +1957,15 @@ fn exec_reauth_part_1() { .unwrap() .run_threaded(); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// Part of the `reauth` test, this loads credentials from a file, /// and passes them to `connect`. /// /// Must run after `exec_reauth_part_1`. -fn exec_reauth_part_2() { +#[cfg(not(target_arch = "wasm32"))] +async fn exec_reauth_part_2() { let test_counter = TestCounter::new(); let name = db_name_or_panic(); @@ -1770,71 +1993,90 @@ fn exec_reauth_part_2() { .unwrap() .run_threaded(); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; +} + +#[cfg(target_arch = "wasm32")] +async fn exec_reauth_part_1() { + // Native-only: requires file-backed credentials via `credentials::File`, + // which is unavailable in wasm/web. +} + +#[cfg(target_arch = "wasm32")] +async fn exec_reauth_part_2() { + // Native-only: requires persisted credentials from `exec_reauth_part_1`. } // Ensure a new connection gets a different connection id. -fn exec_reconnect_different_connection_id() { +async fn exec_reconnect_different_connection_id() { let initial_test_counter = TestCounter::new(); let initial_connect_result = initial_test_counter.add_test("connect"); let disconnect_test_counter = TestCounter::new(); let disconnect_result = disconnect_test_counter.add_test("disconnect"); - let initial_connection = DbConnection::builder() - .with_database_name(db_name_or_panic()) - .with_uri(LOCALHOST) - .on_connect_error(|_ctx, error| panic!("on_connect_error: {error:?}")) - .on_connect(move |_, _, _| { - initial_connect_result(Ok(())); - }) - .on_disconnect(|_, error| match error { - None => disconnect_result(Ok(())), - Some(err) => disconnect_result(Err(anyhow::anyhow!("{err:?}"))), - }) - .build() - .unwrap(); - + let initial_connection = build_connection( + DbConnection::builder() + .with_database_name(db_name_or_panic()) + .with_uri(LOCALHOST) + .on_connect_error(|_ctx, error| panic!("on_connect_error: {error:?}")) + .on_connect(move |_, _, _| { + initial_connect_result(Ok(())); + }) + .on_disconnect(|_, error| match error { + None => disconnect_result(Ok(())), + Some(err) => disconnect_result(Err(anyhow::anyhow!("{err:?}"))), + }), + ) + .await; + + #[cfg(not(target_arch = "wasm32"))] initial_connection.run_threaded(); + #[cfg(target_arch = "wasm32")] + initial_connection.run_background_task(); - initial_test_counter.wait_for_all(); + wait_for_all(&initial_test_counter).await; let my_connection_id = initial_connection.connection_id(); initial_connection.disconnect().unwrap(); - disconnect_test_counter.wait_for_all(); + wait_for_all(&disconnect_test_counter).await; let reconnect_test_counter = TestCounter::new(); let reconnect_result = reconnect_test_counter.add_test("reconnect"); let addr_after_reconnect_result = reconnect_test_counter.add_test("addr_after_reconnect"); - let re_connection = DbConnection::builder() - .with_database_name(db_name_or_panic()) - .with_uri(LOCALHOST) - .on_connect_error(|_ctx, error| panic!("on_connect_error: {error:?}")) - .on_connect(move |ctx, _, _| { - reconnect_result(Ok(())); - let run_checks = || { - // A new connection should have a different connection id. - anyhow::ensure!(ctx.connection_id() != my_connection_id); - Ok(()) - }; - addr_after_reconnect_result(run_checks()); - }) - .build() - .unwrap(); + let re_connection = build_connection( + DbConnection::builder() + .with_database_name(db_name_or_panic()) + .with_uri(LOCALHOST) + .on_connect_error(|_ctx, error| panic!("on_connect_error: {error:?}")) + .on_connect(move |ctx, _, _| { + reconnect_result(Ok(())); + let run_checks = || { + // A new connection should have a different connection id. + anyhow::ensure!(ctx.connection_id() != my_connection_id); + Ok(()) + }; + addr_after_reconnect_result(run_checks()); + }), + ) + .await; + #[cfg(not(target_arch = "wasm32"))] re_connection.run_threaded(); + #[cfg(target_arch = "wasm32")] + re_connection.run_background_task(); - reconnect_test_counter.wait_for_all(); + wait_for_all(&reconnect_test_counter).await; } -fn exec_caller_always_notified() { +async fn exec_caller_always_notified() { let test_counter = TestCounter::new(); let no_op_result = test_counter.add_test("notified_of_no_op_reducer"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; connection .reducers @@ -1857,17 +2099,17 @@ fn exec_caller_always_notified() { }) .unwrap(); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } /// Duplicates the test `insert_primitive`, /// but using `SubscriptionBuilder::subscribe_to_all_tables` rather than an explicit query set. -fn exec_subscribe_all_select_star() { +async fn exec_subscribe_all_select_star() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; connection .subscription_builder() @@ -1901,14 +2143,14 @@ fn exec_subscribe_all_select_star() { .on_error(|_, e| panic!("Subscription error: {e:?}")) .subscribe_to_all_tables(); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } -fn exec_sorted_uuids_insert() { +async fn exec_sorted_uuids_insert() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("sorted-uuids-insert"); - let connection = connect(&test_counter); + let connection = connect(&test_counter).await; subscribe_all_then(&connection, { let test_counter = test_counter.clone(); @@ -1938,10 +2180,10 @@ fn exec_sorted_uuids_insert() { } }); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } -fn exec_caller_alice_receives_reducer_callback_but_not_bob() { +async fn exec_caller_alice_receives_reducer_callback_but_not_bob() { fn check_val(val: T, eq: T) -> anyhow::Result<()> { (val == eq) .then_some(()) @@ -1952,23 +2194,22 @@ fn exec_caller_alice_receives_reducer_callback_but_not_bob() { let pre_ins_counter = TestCounter::new(); // Have two actors, Alice (0) and Bob (1), connect to the module. - // For each actor, subscribe to the `OneU8` table. - // The choice of table is a fairly random one: just one of the simpler tables. - let conns = ["alice", "bob"].map(|who| { - let conn = connect_with_then(&pre_ins_counter, who, |b| b, |_| {}); + async fn connect_actor( + who: &'static str, + pre_ins_counter: &Arc, + counter: &Arc, + ) -> ManagedConnection { + let conn = connect_with_then(pre_ins_counter, who, |b| b, |_| {}).await; let sub_applied = pre_ins_counter.add_test(format!("sub_applied_{who}")); - - let counter2 = counter.clone(); + let counter = counter.clone(); subscribe_all_then(&conn, move |ctx| { - sub_applied(Ok(())); - // Test that we are notified when a row is inserted. let db = ctx.db(); - let mut one_u8_inserted = Some(counter2.add_test(format!("one_u8_inserted_{who}"))); + let mut one_u8_inserted = Some(counter.add_test(format!("one_u8_inserted_{who}"))); db.one_u_8().on_insert(move |_, row| { (one_u8_inserted.take().unwrap())(check_val(row.n, 42)); }); - let mut one_u16_inserted = Some(counter2.add_test(format!("one_u16_inserted_{who}"))); + let mut one_u16_inserted = Some(counter.add_test(format!("one_u16_inserted_{who}"))); let is_alice = who == "alice"; db.one_u_16().on_insert(move |event, row| { let run_checks = || { @@ -1991,15 +2232,29 @@ fn exec_caller_alice_receives_reducer_callback_but_not_bob() { }; (one_u16_inserted.take().unwrap())(run_checks()); }); + + // Mark subscription readiness only after callback handlers are installed. + // On wasm this prevents a race where reducer calls can run before handlers + // are fully registered in the single-threaded event loop. + // The test uses `pre_ins_counter` as the "both clients are safe to use now" barrier. + // In this harness, "safe" must mean the handlers under test are already installed. + // Signalling readiness any earlier would allow Alice to fire the reducer while Bob's + // callbacks still do not exist, which would turn the test into an event-loop race. + sub_applied(Ok(())); }); conn - }); + } + + let conns = [ + connect_actor("alice", &pre_ins_counter, &counter).await, + connect_actor("bob", &pre_ins_counter, &counter).await, + ]; // Ensure both have finished connecting // and finished subscribing so that there isn't a race condition // between Alice executing the reducer and Bob being connected // or Alice executing the reducer and either having subscriptions applied. - pre_ins_counter.wait_for_all(); + wait_for_all(&pre_ins_counter).await; // Alice executes a reducer. // This should cause a row callback to be received by Alice and Bob. @@ -2022,7 +2277,7 @@ fn exec_caller_alice_receives_reducer_callback_but_not_bob() { .insert_one_u_16_then(24, reducer_callback_assert_committed("insert_one_u_16")) .unwrap(); - counter.wait_for_all(); + wait_for_all(&counter).await; // For the integrity of the test, ensure that Alice != Bob. // We do this after `run_threaded` so that the ids have been filled. @@ -2039,7 +2294,7 @@ fn put_result(result: &mut Option, res: Result<(), anyhow::Error (result.take().unwrap())(res); } -fn exec_row_deduplication() { +async fn exec_row_deduplication() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -2092,9 +2347,10 @@ fn exec_row_deduplication() { sub_applied_nothing_result(assert_all_tables_empty(ctx)); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; // Ensure we're not double counting anything. let table = conn.db.pk_u_32(); @@ -2103,7 +2359,7 @@ fn exec_row_deduplication() { assert_eq!(table.n().find(&42), Some(PkU32 { n: 42, data: 0xfeeb })); } -fn exec_row_deduplication_join_r_and_s() { +async fn exec_row_deduplication_join_r_and_s() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -2163,12 +2419,13 @@ fn exec_row_deduplication_join_r_and_s() { put_result(&mut unique_u32_on_insert_result, Ok(())); }); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } -fn exec_row_deduplication_r_join_s_and_r_join_t() { +async fn exec_row_deduplication_r_join_s_and_r_join_t() { let test_counter: Arc = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); @@ -2223,15 +2480,16 @@ fn exec_row_deduplication_r_join_s_and_r_join_t() { UniqueU32::on_delete(ctx, move |_, _| panic!()); PkU32Two::on_delete(ctx, move |_, _| panic!()); } - }); + }) + .await; - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; assert_eq!(count_unique_u32_on_insert.load(Ordering::SeqCst), 1); } /// This test asserts that the correct callbacks are invoked when updating the lhs table of a join -fn test_lhs_join_update() { +async fn test_lhs_join_update() { let insert_counter = TestCounter::new(); let update_counter = TestCounter::new(); let mut on_update_1 = Some(update_counter.add_test("on_update_1")); @@ -2239,18 +2497,21 @@ fn test_lhs_join_update() { let mut on_insert_1 = Some(insert_counter.add_test("on_insert_1")); let mut on_insert_2 = Some(insert_counter.add_test("on_insert_2")); - let conn = Arc::new(connect_then(&update_counter, { - move |ctx| { - subscribe_these_then( - ctx, - &[ - "SELECT p.* FROM pk_u_32 p WHERE n = 1", - "SELECT p.* FROM pk_u_32 p JOIN unique_u_32 u ON p.n = u.n WHERE u.data > 0 AND u.data < 5", - ], - |_| {}, - ); - } - })); + let conn = Arc::new( + connect_then(&update_counter, { + move |ctx| { + subscribe_these_then( + ctx, + &[ + "SELECT p.* FROM pk_u_32 p WHERE n = 1", + "SELECT p.* FROM pk_u_32 p JOIN unique_u_32 u ON p.n = u.n WHERE u.data > 0 AND u.data < 5", + ], + |_| {}, + ); + } + }) + .await, + ); // Add two pk_u32 rows to the subscription conn.reducers @@ -2274,7 +2535,7 @@ fn test_lhs_join_update() { // Wait for the subscription to be updated, // then update one of the pk_u32 rows. - insert_counter.wait_for_all(); + wait_for_all(&insert_counter).await; conn.reducers .update_pk_u_32_then(2, 1, move |ctx, outcome| { assert_outcome_committed("update_pk_u_32", outcome); @@ -2290,11 +2551,11 @@ fn test_lhs_join_update() { .unwrap(); // Wait for the second row update for pk_u32 - update_counter.wait_for_all(); + wait_for_all(&update_counter).await; } /// This test asserts that the correct callbacks are invoked when updating the lhs table of a join -fn test_lhs_join_update_disjoint_queries() { +async fn test_lhs_join_update_disjoint_queries() { let insert_counter = TestCounter::new(); let update_counter = TestCounter::new(); let mut on_update_1 = Some(update_counter.add_test("on_update_1")); @@ -2309,7 +2570,8 @@ fn test_lhs_join_update_disjoint_queries() { "SELECT p.* FROM pk_u_32 p JOIN unique_u_32 u ON p.n = u.n WHERE u.data > 0 AND u.data < 5 AND u.n != 1", ], |_| {}); } - })); + }) + .await); // Add two pk_u32 rows to the subscription conn.reducers @@ -2333,7 +2595,7 @@ fn test_lhs_join_update_disjoint_queries() { // Wait for the subscription to be updated, // then update one of the pk_u32 rows. - insert_counter.wait_for_all(); + wait_for_all(&insert_counter).await; conn.reducers .update_pk_u_32_then(2, 1, move |ctx, outcome| { assert_outcome_committed("update_pk_u_32", outcome); @@ -2349,7 +2611,7 @@ fn test_lhs_join_update_disjoint_queries() { .unwrap(); // Wait for the second row update for pk_u32 - update_counter.wait_for_all(); + wait_for_all(&update_counter).await; } /// Test that when subscribing to a single join query, @@ -2357,7 +2619,7 @@ fn test_lhs_join_update_disjoint_queries() { /// /// This is a regression test for [2397](https://github.com/clockworklabs/SpacetimeDB/issues/2397), /// where the server was incorrectly deduplicating incremental subscription updates. -fn test_intra_query_bag_semantics_for_join() { +async fn test_intra_query_bag_semantics_for_join() { let test_counter = TestCounter::new(); let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); let mut pk_u32_on_delete_result = Some(test_counter.add_test("pk_u32_on_delete")); @@ -2435,14 +2697,21 @@ fn test_intra_query_bag_semantics_for_join() { put_result(&mut pk_u32_on_delete_result, Ok(())); }); } - }); + }) + .await; + + // This test is only complete once both registered expectations have reported: + // `on_subscription_applied_nothing` confirms the initial state, and `pk_u32_on_delete` + // confirms the bag-semantics transition at the end. Without an explicit wait, the harness can + // finish before that final delete callback runs, which leaves the test under-synchronized. + wait_for_all(&test_counter).await; } /// Test that several clients subscribing to the same query and using the same protocol (bsatn) /// can use different compression algorithms than each other. /// /// This is a regression test. -fn exec_two_different_compression_algos() { +async fn exec_two_different_compression_algos() { use Compression::*; // Create 32 KiB of random bytes to make it very likely that compression is used. @@ -2456,17 +2725,17 @@ fn exec_two_different_compression_algos() { // Connect with brotli, gzip, and no compression. // One of them will insert and all of them will subscribe. // All should get back `bytes`. - fn connect_with_compression( + async fn connect_with_compression( test_counter: &Arc, compression_name: &str, compression: Compression, mut recorder: Option, - barrier: &Arc, + subscribed_clients: Arc, + row_inserted: Arc, expected: &Arc<[u8]>, ) { let expected1 = expected.clone(); let expected2 = expected1.clone(); - let barrier = barrier.clone(); connect_with_then( test_counter, compression_name, @@ -2485,57 +2754,95 @@ fn exec_two_different_compression_algos() { put_result(&mut recorder, res) }); - // All clients must have subscribed and registered the `on_insert` callback - // before we actually insert the row. - barrier.wait(); - - if compression == None { + // We cannot block inside wasm callbacks (e.g. with `Barrier::wait`), + // because that stalls the event loop and prevents other subscriptions + // from applying. Instead we atomically gate on the third subscriber. + // All three subscriptions must be live before any client inserts the row under + // test, otherwise the last client could miss the only insert. We coordinate + // with atomics instead of a blocking barrier because callback code shares the + // single-threaded wasm event loop with subscription delivery. + if subscribed_clients.fetch_add(1, Ordering::SeqCst) + 1 == 3 + && !row_inserted.swap(true, Ordering::SeqCst) + { VecU8::insert(ctx, expected2.to_vec()); } }) }, - ); + ) + .await; } let test_counter: Arc = TestCounter::new(); - let barrier = Arc::new(Barrier::new(3)); + let subscribed_clients = Arc::new(AtomicUsize::new(0)); + let row_inserted = Arc::new(AtomicBool::new(false)); let got_brotli = Some(test_counter.add_test("got_right_row_brotli")); let got_gzip = Some(test_counter.add_test("got_right_row_gzip")); let got_none = Some(test_counter.add_test("got_right_row_none")); - connect_with_compression(&test_counter, "brotli", Brotli, got_brotli, &barrier, &bytes); - connect_with_compression(&test_counter, "gzip", Gzip, got_gzip, &barrier, &bytes); - connect_with_compression(&test_counter, "none", None, got_none, &barrier, &bytes); - test_counter.wait_for_all(); + connect_with_compression( + &test_counter, + "brotli", + Brotli, + got_brotli, + subscribed_clients.clone(), + row_inserted.clone(), + &bytes, + ) + .await; + connect_with_compression( + &test_counter, + "gzip", + Gzip, + got_gzip, + subscribed_clients.clone(), + row_inserted.clone(), + &bytes, + ) + .await; + connect_with_compression( + &test_counter, + "none", + None, + got_none, + subscribed_clients, + row_inserted, + &bytes, + ) + .await; + wait_for_all(&test_counter).await; } /// In this test we have two clients issue parameterized subscriptions. /// These subscriptions are identical syntactically but not semantically, /// because they are parameterized by `:sender` - the caller's identity. -fn test_parameterized_subscription() { +async fn test_parameterized_subscription() { + // This test cares about per-client query parameterization, not about both clients reaching a + // global barrier at exactly the same moment. Each client only mutates state after its own + // subscription is applied, which is the point at which that client's visibility guarantee + // becomes meaningful on both native and wasm runtimes. let ctr_for_test = TestCounter::new(); - let ctr_for_subs = TestCounter::new(); - let sub_0 = Some(ctr_for_subs.add_test("sub_0")); - let sub_1 = Some(ctr_for_subs.add_test("sub_1")); + let sub_0 = Some(ctr_for_test.add_test("sub_0")); + let sub_1 = Some(ctr_for_test.add_test("sub_1")); let insert_0 = Some(ctr_for_test.add_test("insert_0")); let insert_1 = Some(ctr_for_test.add_test("insert_1")); let update_0 = Some(ctr_for_test.add_test("update_0")); let update_1 = Some(ctr_for_test.add_test("update_1")); - fn subscribe_and_update( + async fn subscribe_and_update( test_name: &str, old: i32, new: i32, - waiters: [Arc; 2], + ctr_for_test: Arc, senders: [Option; 3], ) { - let [ctr_for_test, ctr_for_subs] = waiters; let [mut record_sub, mut record_ins, mut record_upd] = senders; connect_with_then(&ctr_for_test, test_name, |builder| builder, { move |ctx| { let sender = ctx.identity(); subscribe_these_then(ctx, &["SELECT * FROM pk_identity WHERE i = :sender"], move |ctx| { put_result(&mut record_sub, Ok(())); - // Wait to insert until both client connections have been made - ctr_for_subs.wait_for_all(); + // Trigger the writes from inside this caller's subscription-applied callback + // so we know the parameterized query is active before the rows are created. + // Callback code must stay non-blocking here because wasm delivers both the + // subscription event and subsequent row callbacks on the same event loop. PkIdentity::insert(ctx, sender, old); PkIdentity::update(ctx, sender, new); }); @@ -2552,24 +2859,13 @@ fn test_parameterized_subscription() { put_result(&mut record_upd, Ok(())); }); } - }); + }) + .await; } - subscribe_and_update( - "client_0", - 1, - 2, - [ctr_for_test.clone(), ctr_for_subs.clone()], - [sub_0, insert_0, update_0], - ); - subscribe_and_update( - "client_1", - 3, - 4, - [ctr_for_test.clone(), ctr_for_subs.clone()], - [sub_1, insert_1, update_1], - ); - ctr_for_test.wait_for_all(); + subscribe_and_update("client_0", 1, 2, ctr_for_test.clone(), [sub_0, insert_0, update_0]).await; + subscribe_and_update("client_1", 3, 4, ctr_for_test.clone(), [sub_1, insert_1, update_1]).await; + wait_for_all(&ctr_for_test).await; } /// In this test we have two clients subscribe to the `users` table. @@ -2581,21 +2877,22 @@ fn test_parameterized_subscription() { /// ); /// ``` /// Hence each client should receive different rows. -fn test_rls_subscription() { +async fn test_rls_subscription() { + // Same principle as `test_parameterized_subscription`: the property under test is "what rows + // is this client allowed to observe once its subscription is active?", not "did two clients + // reach a synchronization barrier at the same instant?". let ctr_for_test = TestCounter::new(); - let ctr_for_subs = TestCounter::new(); - let sub_0 = Some(ctr_for_subs.add_test("sub_0")); - let sub_1 = Some(ctr_for_subs.add_test("sub_1")); + let sub_0 = Some(ctr_for_test.add_test("sub_0")); + let sub_1 = Some(ctr_for_test.add_test("sub_1")); let ins_0 = Some(ctr_for_test.add_test("insert_0")); let ins_1 = Some(ctr_for_test.add_test("insert_1")); - fn subscribe_and_update( + async fn subscribe_and_update( test_name: &str, user_name: &str, - waiters: [Arc; 2], + ctr_for_test: Arc, senders: [Option; 2], ) { - let [ctr_for_test, ctr_for_subs] = waiters; let [mut record_sub, mut record_ins] = senders; let user_name = user_name.to_owned(); let expected_name = user_name.to_owned(); @@ -2605,8 +2902,9 @@ fn test_rls_subscription() { let expected_identity = sender; subscribe_these_then(ctx, &["SELECT * FROM users"], move |ctx| { put_result(&mut record_sub, Ok(())); - // Wait to insert until both client connections have been made - ctr_for_subs.wait_for_all(); + // Invoke the reducer only after this client's RLS-filtered subscription is + // active. As above, callback code must remain non-blocking so wasm can keep + // delivering websocket and row events on the same event loop. ctx.reducers .insert_user_then(user_name, sender, reducer_callback_assert_committed("insert_user")) .unwrap(); @@ -2617,25 +2915,16 @@ fn test_rls_subscription() { put_result(&mut record_ins, Ok(())); }); } - }); + }) + .await; } - subscribe_and_update( - "client_0", - "Alice", - [ctr_for_test.clone(), ctr_for_subs.clone()], - [sub_0, ins_0], - ); - subscribe_and_update( - "client_1", - "Bob", - [ctr_for_test.clone(), ctr_for_subs.clone()], - [sub_1, ins_1], - ); - ctr_for_test.wait_for_all(); + subscribe_and_update("client_0", "Alice", ctr_for_test.clone(), [sub_0, ins_0]).await; + subscribe_and_update("client_1", "Bob", ctr_for_test.clone(), [sub_1, ins_1]).await; + wait_for_all(&ctr_for_test).await; } -fn exec_pk_simple_enum() { +async fn exec_pk_simple_enum() { let test_counter: Arc = TestCounter::new(); let mut updated = Some(test_counter.add_test("updated")); connect_then(&test_counter, move |ctx| { @@ -2662,11 +2951,12 @@ fn exec_pk_simple_enum() { .insert_pk_simple_enum_then(a, data1, reducer_callback_assert_committed("insert_pk_simple_enum")) .unwrap(); }); - }); - test_counter.wait_for_all(); + }) + .await; + wait_for_all(&test_counter).await; } -fn exec_indexed_simple_enum() { +async fn exec_indexed_simple_enum() { let test_counter: Arc = TestCounter::new(); let mut updated = Some(test_counter.add_test("updated")); connect_then(&test_counter, move |ctx| { @@ -2695,8 +2985,9 @@ fn exec_indexed_simple_enum() { ) .unwrap(); }); - }); - test_counter.wait_for_all(); + }) + .await; + wait_for_all(&test_counter).await; } /// This tests for a bug we once had where the Rust client SDK would @@ -2709,7 +3000,7 @@ fn exec_indexed_simple_enum() { /// but would then see incremental updates from all of the queries. /// /// A simple reproducer is available at [https://github.com/lavirlifiliol/spacetime-repro]. -fn exec_overlapping_subscriptions() { +async fn exec_overlapping_subscriptions() { // First, a bit of setup: insert the row `{ n: 1, data: 0 }`, // and wait for it to be present. let setup_counter = TestCounter::new(); @@ -2728,9 +3019,10 @@ fn exec_overlapping_subscriptions() { let _ = status; }); call_insert_result(insert_result.map_err(|e| e.into())); - }); + }) + .await; - setup_counter.wait_for_all(); + wait_for_all(&setup_counter).await; let test_counter = TestCounter::new(); @@ -2780,5 +3072,5 @@ fn exec_overlapping_subscriptions() { .map_err(|e| e.into()), ); - test_counter.wait_for_all(); + wait_for_all(&test_counter).await; } diff --git a/sdks/rust/tests/test-counter/Cargo.toml b/sdks/rust/tests/test-counter/Cargo.toml index 969dc417cf6..fa0c9cebb33 100644 --- a/sdks/rust/tests/test-counter/Cargo.toml +++ b/sdks/rust/tests/test-counter/Cargo.toml @@ -9,6 +9,8 @@ license-file = "LICENSE" [dependencies] spacetimedb-data-structures.workspace = true anyhow.workspace = true +futures.workspace = true +gloo-timers = { version = "0.3.0", features = ["futures"] } [lints] workspace = true diff --git a/sdks/rust/tests/test-counter/src/lib.rs b/sdks/rust/tests/test-counter/src/lib.rs index 3eb8eb2f305..f433e823624 100644 --- a/sdks/rust/tests/test-counter/src/lib.rs +++ b/sdks/rust/tests/test-counter/src/lib.rs @@ -1,10 +1,9 @@ #![allow(clippy::disallowed_macros)] use spacetimedb_data_structures::map::{HashMap, HashSet}; -use std::{ - sync::{Arc, Condvar, Mutex}, - time::Duration, -}; +use std::sync::{Arc, Condvar, Mutex}; +#[cfg(not(target_arch = "wasm32"))] +use std::time::Duration; const TEST_TIMEOUT_SECS: u64 = 5 * 60; @@ -55,7 +54,31 @@ impl TestCounter { }) } + // Keep this legacy sync API for existing native-first callers. + // wasm callers should prefer `wait_for_all_async` so we do not + // block the JS event loop while waiting for callbacks. pub fn wait_for_all(&self) { + #[cfg(target_arch = "wasm32")] + futures::executor::block_on(self.wait_for_all_wasm_async()); + + #[cfg(not(target_arch = "wasm32"))] + self.wait_for_all_native(); + } + + pub async fn wait_for_all_async(&self) { + #[cfg(target_arch = "wasm32")] + { + // wasm/web test clients run callbacks on a single-threaded event loop, + // so waiting must be async to allow callback tasks to make progress. + self.wait_for_all_wasm_async().await; + } + + #[cfg(not(target_arch = "wasm32"))] + self.wait_for_all_native(); + } + + #[cfg(not(target_arch = "wasm32"))] + fn wait_for_all_native(&self) { let lock = self.inner.lock().expect("TestCounterInner Mutex is poisoned"); let (lock, timeout_result) = self .wait_until_done @@ -100,4 +123,73 @@ impl TestCounter { } } } + + #[cfg(target_arch = "wasm32")] + async fn wait_for_all_wasm_async(&self) { + use gloo_timers::future::TimeoutFuture; + + const WAIT_INTERVAL_MS: u32 = 10; + const MAX_WAIT_ITERATIONS: u32 = (TEST_TIMEOUT_SECS as u32 * 1000) / WAIT_INTERVAL_MS; + + // Native can block on a Condvar because callbacks keep moving on a different SDK thread. + // wasm/web does not have that escape hatch in this harness: the websocket/message loop and + // the test body share the same single-threaded JS event loop, so blocking here would stop + // callback delivery entirely. We poll with timer yields so websocket/callback tasks can + // continue to run, and then do the same final pass native uses to convert recorded failures + // into a panic. + let all_tests_finished = || { + let inner = self.inner.lock().expect("TestCounterInner Mutex is poisoned"); + inner.outcomes.len() == inner.registered.len() + }; + + let mut finished = false; + for _ in 0..MAX_WAIT_ITERATIONS { + if all_tests_finished() { + // We still need the final outcome pass below. Returning here would incorrectly + // treat recorded `Err(...)` test outcomes as success, including harness tests that + // intentionally exercise the failure path. + finished = true; + break; + } + TimeoutFuture::new(WAIT_INTERVAL_MS).await; + } + + let lock = self.inner.lock().expect("TestCounterInner Mutex is poisoned"); + if !finished || lock.outcomes.len() != lock.registered.len() { + let mut timeout_count = 0; + let mut failed_count = 0; + for test in lock.registered.iter() { + match lock.outcomes.get(test) { + None => { + timeout_count += 1; + println!("TIMEOUT: {test}"); + } + Some(Err(e)) => { + failed_count += 1; + println!("FAILED: {test}:\n\t{e:?}\n"); + } + Some(Ok(())) => { + println!("PASSED: {test}"); + } + } + } + panic!("{timeout_count} tests timed out and {failed_count} tests failed"); + } else { + let mut failed_count = 0; + for (test, outcome) in lock.outcomes.iter() { + match outcome { + Ok(()) => println!("PASSED: {test}"), + Err(e) => { + failed_count += 1; + println!("FAILED: {test}:\n\t{e:?}\n"); + } + } + } + if failed_count != 0 { + panic!("{failed_count} tests failed"); + } else { + println!("All tests passed"); + } + } + } } diff --git a/sdks/rust/tests/test.rs b/sdks/rust/tests/test.rs index 02f0357a524..34a42fa344f 100644 --- a/sdks/rust/tests/test.rs +++ b/sdks/rust/tests/test.rs @@ -1,3 +1,55 @@ +#[cfg(feature = "sdk-tests-web-client")] +use std::path::Path; + +use spacetimedb_testing::sdk::TestBuilder; + +fn configure_test_client_commands( + builder: TestBuilder, + client_project: &str, + run_selector: Option<&str>, +) -> TestBuilder { + // Note: `run_selector` is intentionally interpreted differently by mode: + // - Native mode uses it as a CLI subcommand (`cargo run -- `), with `None` => `cargo run`. + // - Web mode forwards it to the wasm export `run(test_name)`, with `None` => empty string. + // This mirrors how `run_command` is consumed by the native vs web runners in `crates/testing/src/sdk.rs`. + #[cfg(feature = "sdk-tests-web-client")] + { + let package_name = Path::new(client_project) + .file_name() + .and_then(|name| name.to_str()) + .expect("client project path should end in a UTF-8 directory name"); + let artifact_name = package_name.replace('-', "_"); + + // Cargo workspace members emit into the workspace target directory, not each crate's local `./target`. + // Use CARGO_TARGET_DIR when set (e.g. in CI), otherwise fall back to `/target`. + let target_dir = std::env::var("CARGO_TARGET_DIR").unwrap_or_else(|_| { + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../../target") + .to_string_lossy() + .into_owned() + }); + let wasm_path = format!("{target_dir}/wasm32-unknown-unknown/debug/deps/{artifact_name}.wasm"); + let bindgen_out_dir = format!("target/sdk-test-web-bindgen/{package_name}"); + + builder + .with_compile_command("cargo build --target wasm32-unknown-unknown --no-default-features --features web") + .with_run_command(run_selector.unwrap_or_default()) + .with_web_client(wasm_path, bindgen_out_dir) + } + + #[cfg(not(feature = "sdk-tests-web-client"))] + { + let run_command = match run_selector { + Some(subcommand) => format!("cargo run -- {}", subcommand), + None => "cargo run".to_owned(), + }; + + builder + .with_compile_command("cargo build") + .with_run_command(run_command) + } +} + macro_rules! declare_tests_with_suffix { ($lang:ident, $suffix:literal) => { mod $lang { @@ -7,21 +59,23 @@ macro_rules! declare_tests_with_suffix { const CLIENT: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/test-client"); fn make_test(subcommand: &str) -> Test { - Test::builder() - .with_name(subcommand) - .with_module(MODULE) - .with_client(CLIENT) - .with_language("rust") - // We test against multiple modules in different languages, - // and as of writing (pgoldman 2026-02-12), - // some of those languages have not yet been updated to make scheduled and lifecycle reducers - // private by default. As such, generating only public items results in different bindings - // depending on which module is the source. - .with_generate_private_items(true) - .with_bindings_dir("src/module_bindings") - .with_compile_command("cargo build") - .with_run_command(format!("cargo run -- {}", subcommand)) - .build() + super::configure_test_client_commands( + Test::builder() + .with_name(subcommand) + .with_module(MODULE) + .with_client(CLIENT) + .with_language("rust") + // We test against multiple modules in different languages, + // and as of writing (pgoldman 2026-02-12), + // some of those languages have not yet been updated to make scheduled and lifecycle reducers + // private by default. As such, generating only public items results in different bindings + // depending on which module is the source. + .with_generate_private_items(true) + .with_bindings_dir("src/module_bindings"), + CLIENT, + Some(subcommand), + ) + .build() } #[test] @@ -197,25 +251,27 @@ macro_rules! declare_tests_with_suffix { #[test] fn connect_disconnect_callbacks() { - Test::builder() - .with_name(concat!("connect-disconnect-callback-", stringify!($lang))) - .with_module(concat!("sdk-test-connect-disconnect", $suffix)) - .with_client(concat!( - env!("CARGO_MANIFEST_DIR"), - "/tests/connect_disconnect_client" - )) - .with_language("rust") - // We test against multiple modules in different languages, - // and as of writing (pgoldman 2026-02-12), - // some of those languages have not yet been updated to make scheduled and lifecycle reducers - // private by default. As such, generating only public items results in different bindings - // depending on which module is the source. - .with_generate_private_items(true) - .with_bindings_dir("src/module_bindings") - .with_compile_command("cargo build") - .with_run_command("cargo run") - .build() - .run(); + const CONNECT_DISCONNECT_CLIENT: &str = + concat!(env!("CARGO_MANIFEST_DIR"), "/tests/connect_disconnect_client"); + + super::configure_test_client_commands( + Test::builder() + .with_name(concat!("connect-disconnect-callback-", stringify!($lang))) + .with_module(concat!("sdk-test-connect-disconnect", $suffix)) + .with_client(CONNECT_DISCONNECT_CLIENT) + .with_language("rust") + // We test against multiple modules in different languages, + // and as of writing (pgoldman 2026-02-12), + // some of those languages have not yet been updated to make scheduled and lifecycle reducers + // private by default. As such, generating only public items results in different bindings + // depending on which module is the source. + .with_generate_private_items(true) + .with_bindings_dir("src/module_bindings"), + CONNECT_DISCONNECT_CLIENT, + None, + ) + .build() + .run(); } #[test] @@ -327,15 +383,17 @@ mod event_table_tests { const CLIENT: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/event-table-client"); fn make_test(subcommand: &str) -> Test { - Test::builder() - .with_name(subcommand) - .with_module(MODULE) - .with_client(CLIENT) - .with_language("rust") - .with_bindings_dir("src/module_bindings") - .with_compile_command("cargo build") - .with_run_command(format!("cargo run -- {}", subcommand)) - .build() + super::configure_test_client_commands( + Test::builder() + .with_name(subcommand) + .with_module(MODULE) + .with_client(CLIENT) + .with_language("rust") + .with_bindings_dir("src/module_bindings"), + CLIENT, + Some(subcommand), + ) + .build() } #[test] @@ -368,21 +426,23 @@ macro_rules! procedure_tests { const CLIENT: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/procedure-client"); fn make_test(subcommand: &str) -> Test { - Test::builder() - .with_name(subcommand) - .with_module(MODULE) - .with_client(CLIENT) - .with_language("rust") - // We test against multiple modules in different languages, - // and as of writing (pgoldman 2026-02-12), - // some of those languages have not yet been updated to make scheduled and lifecycle reducers - // private by default. As such, generating only public items results in different bindings - // depending on which module is the source. - .with_generate_private_items(true) - .with_bindings_dir("src/module_bindings") - .with_compile_command("cargo build") - .with_run_command(format!("cargo run -- {}", subcommand)) - .build() + super::configure_test_client_commands( + Test::builder() + .with_name(subcommand) + .with_module(MODULE) + .with_client(CLIENT) + .with_language("rust") + // We test against multiple modules in different languages, + // and as of writing (pgoldman 2026-02-12), + // some of those languages have not yet been updated to make scheduled and lifecycle reducers + // private by default. As such, generating only public items results in different bindings + // depending on which module is the source. + .with_generate_private_items(true) + .with_bindings_dir("src/module_bindings"), + CLIENT, + Some(subcommand), + ) + .build() } #[test] @@ -436,21 +496,23 @@ macro_rules! view_tests { const CLIENT: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/view-client"); fn make_test(subcommand: &str) -> Test { - Test::builder() - .with_name(subcommand) - .with_module(MODULE) - .with_client(CLIENT) - .with_language("rust") - // We test against multiple modules in different languages, - // and as of writing (pgoldman 2026-02-12), - // some of those languages have not yet been updated to make scheduled and lifecycle reducers - // private by default. As such, generating only public items results in different bindings - // depending on which module is the source. - .with_generate_private_items(true) - .with_bindings_dir("src/module_bindings") - .with_compile_command("cargo build") - .with_run_command(format!("cargo run -- {}", subcommand)) - .build() + super::configure_test_client_commands( + Test::builder() + .with_name(subcommand) + .with_module(MODULE) + .with_client(CLIENT) + .with_language("rust") + // We test against multiple modules in different languages, + // and as of writing (pgoldman 2026-02-12), + // some of those languages have not yet been updated to make scheduled and lifecycle reducers + // private by default. As such, generating only public items results in different bindings + // depending on which module is the source. + .with_generate_private_items(true) + .with_bindings_dir("src/module_bindings"), + CLIENT, + Some(subcommand), + ) + .build() } #[test] @@ -498,15 +560,17 @@ macro_rules! view_pk_tests { const CLIENT: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/view-pk-client"); fn make_test(subcommand: &str) -> Test { - Test::builder() - .with_name(subcommand) - .with_module(MODULE) - .with_client(CLIENT) - .with_language("rust") - .with_bindings_dir("src/module_bindings") - .with_compile_command("cargo build") - .with_run_command(format!("cargo run -- {}", subcommand)) - .build() + super::configure_test_client_commands( + Test::builder() + .with_name(subcommand) + .with_module(MODULE) + .with_client(CLIENT) + .with_language("rust") + .with_bindings_dir("src/module_bindings"), + CLIENT, + Some(subcommand), + ) + .build() } #[test] diff --git a/sdks/rust/tests/view-client/Cargo.toml b/sdks/rust/tests/view-client/Cargo.toml index 76c6bb58a8e..5c57f11ff2f 100644 --- a/sdks/rust/tests/view-client/Cargo.toml +++ b/sdks/rust/tests/view-client/Cargo.toml @@ -6,12 +6,40 @@ license-file = "LICENSE" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +crate-type = ["cdylib", "rlib"] + +[features] +default = ["native"] + +# Builds the existing CLI test client. +native = [ + "dep:env_logger", +] + +# Builds the client for wasm32-unknown-unknown using the Rust SDK `web` backend. +web = [ + "spacetimedb-sdk/web", + "dep:wasm-bindgen", + "dep:wasm-bindgen-futures", + "dep:futures", +] + +[[bin]] +name = "view-client" +path = "src/main.rs" +required-features = ["native"] + [dependencies] spacetimedb-sdk = { path = "../.." } spacetimedb-lib.workspace = true test-counter = { path = "../test-counter" } anyhow.workspace = true -env_logger.workspace = true +env_logger = { workspace = true, optional = true } +futures = { workspace = true, optional = true } + +wasm-bindgen = { version = "0.2.100", optional = true } +wasm-bindgen-futures = { version = "0.4.45", optional = true } [lints] workspace = true diff --git a/sdks/rust/tests/view-client/src/lib.rs b/sdks/rust/tests/view-client/src/lib.rs new file mode 100644 index 00000000000..9dd50d80e2b --- /dev/null +++ b/sdks/rust/tests/view-client/src/lib.rs @@ -0,0 +1,13 @@ +#![allow(clippy::disallowed_macros)] + +#[path = "main.rs"] +mod cli; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +use wasm_bindgen::prelude::wasm_bindgen; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +#[wasm_bindgen] +pub async fn run(test_name: String) { + cli::dispatch(&test_name); +} diff --git a/sdks/rust/tests/view-client/src/main.rs b/sdks/rust/tests/view-client/src/main.rs index 96312b4f40e..d105ed2d555 100644 --- a/sdks/rust/tests/view-client/src/main.rs +++ b/sdks/rust/tests/view-client/src/main.rs @@ -1,4 +1,4 @@ -mod module_bindings; +pub(crate) mod module_bindings; use module_bindings::*; use spacetimedb_lib::Identity; @@ -10,6 +10,7 @@ const LOCALHOST: &str = "http://localhost:3000"; /// Register a panic hook which will exit the process whenever any thread panics. /// /// This allows us to fail tests by panicking in callbacks. +#[cfg(not(target_arch = "wasm32"))] fn exit_on_panic() { // The default panic hook is responsible for printing the panic message and backtrace to stderr. // Grab a handle on it, and invoke it in our custom hook before exiting. @@ -27,6 +28,7 @@ fn db_name_or_panic() -> String { std::env::var("SPACETIME_SDK_TEST_DB_NAME").expect("Failed to read db name from env") } +#[cfg(not(target_arch = "wasm32"))] fn main() { env_logger::init(); exit_on_panic(); @@ -35,6 +37,10 @@ fn main() { .nth(1) .expect("Pass a test name as a command-line argument to the test client"); + dispatch(&test); +} + +pub(crate) fn dispatch(test: &str) { match &*test { "view-anonymous-subscribe" => exec_anonymous_subscribe(), "view-anonymous-subscribe-with-query-builder" => exec_anonymous_subscribe_with_query_builder(), @@ -47,6 +53,16 @@ fn main() { } } +#[cfg(not(target_arch = "wasm32"))] +fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + builder.build().unwrap() +} + +#[cfg(target_arch = "wasm32")] +fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + futures::executor::block_on(builder.build()).unwrap() +} + fn connect_with_then( test_counter: &std::sync::Arc, on_connect_suffix: &str, @@ -63,8 +79,11 @@ fn connect_with_then( connected_result(Ok(())); }) .on_connect_error(|_ctx, error| panic!("Connect errored: {error:?}")); - let conn = with_builder(builder).build().unwrap(); + let conn = build_connection(with_builder(builder)); + #[cfg(not(target_arch = "wasm32"))] conn.run_threaded(); + #[cfg(target_arch = "wasm32")] + conn.run_background_task(); conn } diff --git a/sdks/rust/tests/view-pk-client/Cargo.toml b/sdks/rust/tests/view-pk-client/Cargo.toml index f872b1e7f17..77658606a79 100644 --- a/sdks/rust/tests/view-pk-client/Cargo.toml +++ b/sdks/rust/tests/view-pk-client/Cargo.toml @@ -4,11 +4,39 @@ version.workspace = true edition.workspace = true license-file = "LICENSE" +[lib] +crate-type = ["cdylib", "rlib"] + +[features] +default = ["native"] + +# Builds the existing CLI test client. +native = [ + "dep:env_logger", +] + +# Builds the client for wasm32-unknown-unknown using the Rust SDK `web` backend. +web = [ + "spacetimedb-sdk/web", + "dep:wasm-bindgen", + "dep:wasm-bindgen-futures", + "dep:futures", +] + +[[bin]] +name = "view-pk-client" +path = "src/main.rs" +required-features = ["native"] + [dependencies] spacetimedb-sdk = { path = "../.." } test-counter = { path = "../test-counter" } anyhow.workspace = true -env_logger.workspace = true +env_logger = { workspace = true, optional = true } +futures = { workspace = true, optional = true } + +wasm-bindgen = { version = "0.2.100", optional = true } +wasm-bindgen-futures = { version = "0.4.45", optional = true } [lints] workspace = true diff --git a/sdks/rust/tests/view-pk-client/src/lib.rs b/sdks/rust/tests/view-pk-client/src/lib.rs new file mode 100644 index 00000000000..9dd50d80e2b --- /dev/null +++ b/sdks/rust/tests/view-pk-client/src/lib.rs @@ -0,0 +1,13 @@ +#![allow(clippy::disallowed_macros)] + +#[path = "main.rs"] +mod cli; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +use wasm_bindgen::prelude::wasm_bindgen; + +#[cfg(all(target_arch = "wasm32", feature = "web"))] +#[wasm_bindgen] +pub async fn run(test_name: String) { + cli::dispatch(&test_name); +} diff --git a/sdks/rust/tests/view-pk-client/src/main.rs b/sdks/rust/tests/view-pk-client/src/main.rs index 38f73e5f91e..01e3b432dc3 100644 --- a/sdks/rust/tests/view-pk-client/src/main.rs +++ b/sdks/rust/tests/view-pk-client/src/main.rs @@ -1,14 +1,15 @@ -mod module_bindings; +pub(crate) mod module_bindings; use module_bindings::*; use spacetimedb_sdk::TableWithPrimaryKey; -use spacetimedb_sdk::{error::InternalError, DbContext}; +use spacetimedb_sdk::{error::InternalError, DbConnectionBuilder, DbContext}; use test_counter::TestCounter; const LOCALHOST: &str = "http://localhost:3000"; type ResultRecorder = Box)>; +#[cfg(not(target_arch = "wasm32"))] fn exit_on_panic() { let default_hook = std::panic::take_hook(); std::panic::set_hook(Box::new(move |panic_info| { @@ -21,6 +22,16 @@ fn db_name_or_panic() -> String { std::env::var("SPACETIME_SDK_TEST_DB_NAME").expect("Failed to read db name from env") } +#[cfg(not(target_arch = "wasm32"))] +fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + builder.build().unwrap() +} + +#[cfg(target_arch = "wasm32")] +fn build_connection(builder: DbConnectionBuilder) -> DbConnection { + futures::executor::block_on(builder.build()).unwrap() +} + fn put_result(result: &mut Option, res: Result<(), anyhow::Error>) { (result.take().unwrap())(res); } @@ -48,10 +59,12 @@ fn connect_then( callback(ctx); connected_result(Ok(())); }) - .on_connect_error(|_ctx, error| panic!("Connect errored: {error:?}")) - .build() - .unwrap(); + .on_connect_error(|_ctx, error| panic!("Connect errored: {error:?}")); + let conn = build_connection(conn); + #[cfg(not(target_arch = "wasm32"))] conn.run_threaded(); + #[cfg(target_arch = "wasm32")] + conn.run_background_task(); conn } @@ -279,6 +292,7 @@ fn exec_view_pk_semijoin_two_sender_views_query_builder() { test_counter.wait_for_all(); } +#[cfg(not(target_arch = "wasm32"))] fn main() { env_logger::init(); exit_on_panic(); @@ -287,6 +301,10 @@ fn main() { .nth(1) .expect("Pass a test name as a command-line argument to the test client"); + dispatch(&test); +} + +pub(crate) fn dispatch(test: &str) { match &*test { "view-pk-on-update" => exec_view_pk_on_update(), "view-pk-join-query-builder" => exec_view_pk_join_query_builder(), diff --git a/tools/ci/src/main.rs b/tools/ci/src/main.rs index fa5d851807e..f942f950f38 100644 --- a/tools/ci/src/main.rs +++ b/tools/ci/src/main.rs @@ -338,6 +338,20 @@ fn main() -> Result<()> { "unreal" ) .run()?; + // Run the same SDK suite against wasm+web test clients. + cmd!( + "cargo", + "test", + "-p", + "spacetimedb-sdk", + "--features", + "allow_loopback_http_for_tests,sdk-tests-web-client", + "--", + "--test-threads=2", + "--skip", + "unreal" + ) + .run()?; // TODO: This should check for a diff at the start. If there is one, we should alert the user // that we're disabling diff checks because they have a dirty git repo, and to re-run in a clean one // if they want those checks.