diff --git a/Cargo.lock b/Cargo.lock index ec1f3551..9d26de75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -741,6 +741,9 @@ name = "fastrand" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +dependencies = [ + "getrandom 0.2.16", +] [[package]] name = "figment" @@ -781,6 +784,18 @@ dependencies = [ "spin", ] +[[package]] +name = "flume" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" +dependencies = [ + "fastrand", + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -940,8 +955,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -1189,7 +1206,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.0", + "socket2", "tokio", "tower-service", "tracing", @@ -1999,9 +2016,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ "bytes", "prost-derive", @@ -2009,9 +2026,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", "itertools", @@ -2022,9 +2039,9 @@ dependencies = [ [[package]] name = "prost-types" 
-version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ "prost", ] @@ -2541,13 +2558,14 @@ dependencies = [ [[package]] name = "sentry_protos" -version = "0.4.11" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae1eac4a748b11a2bb5b342bea8546085751cf9a45e30fb1276b072bb5541e6" +checksum = "5d3c4e8bca4c556eec616dc2594e519248891ca84f8bf958016c2c416223d8ff" dependencies = [ "prost", "prost-types", "tonic", + "tonic-prost", ] [[package]] @@ -2688,16 +2706,6 @@ dependencies = [ "serde", ] -[[package]] -name = "socket2" -version = "0.5.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "socket2" version = "0.6.0" @@ -2902,7 +2910,7 @@ checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" dependencies = [ "atoi", "chrono", - "flume", + "flume 0.11.1", "futures-channel", "futures-core", "futures-executor", @@ -2992,6 +3000,7 @@ dependencies = [ "derive_builder", "elegant-departure", "figment", + "flume 0.12.0", "futures", "futures-util", "hex", @@ -3166,7 +3175,7 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "slab", - "socket2 0.6.0", + "socket2", "tokio-macros", "windows-sys 0.59.0", ] @@ -3236,9 +3245,9 @@ dependencies = [ [[package]] name = "tonic" -version = "0.13.1" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ "async-trait", "axum", @@ -3253,8 +3262,8 @@ dependencies = [ 
"hyper-util", "percent-encoding", "pin-project", - "prost", - "socket2 0.5.10", + "socket2", + "sync_wrapper", "tokio", "tokio-stream", "tower", @@ -3265,14 +3274,26 @@ dependencies = [ [[package]] name = "tonic-health" -version = "0.13.1" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb87334d340313fefa513b6e60794d44a86d5f039b523229c99c323e4e19ca4b" +checksum = "f4ff0636fef47afb3ec02818f5bceb4377b8abb9d6a386aeade18bd6212f8eb7" dependencies = [ "prost", "tokio", "tokio-stream", "tonic", + "tonic-prost", +] + +[[package]] +name = "tonic-prost" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" +dependencies = [ + "bytes", + "prost", + "tonic", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 67119a08..a9758e93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ clap = { version = "4.5.20", features = ["derive"] } derive_builder = "0.20.2" elegant-departure = { version = "0.3.1", features = ["tokio"] } figment = { version = "0.10.19", features = ["env", "yaml", "test"] } +flume = "0.12.0" futures = "0.3.31" futures-util = "0.3.31" hex = "0.4.3" @@ -26,8 +27,8 @@ http-body-util = "0.1.2" libsqlite3-sys = "0.30.1" metrics = "0.24.0" metrics-exporter-statsd = "0.9.0" -prost = "0.13" -prost-types = "0.13.3" +prost = "0.14" +prost-types = "0.14" rand = "0.8.5" rdkafka = { version = "0.37.0", features = ["cmake-build", "ssl"] } sentry = { version = "0.41.0", default-features = false, features = [ @@ -41,7 +42,7 @@ sentry = { version = "0.41.0", default-features = false, features = [ "tracing", "logs" ] } -sentry_protos = "0.4.11" +sentry_protos = "0.8.5" serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" @@ -49,8 +50,8 @@ sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono", "po tokio = { version = "1.43.1", features = ["full"] } tokio-stream = { version = "0.1.16", 
features = ["full"] } tokio-util = "0.7.12" -tonic = "0.13" -tonic-health = "0.13" +tonic = "0.14" +tonic-health = "0.14" tower = "0.5.1" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = [ diff --git a/README.md b/README.md index 860f446b..7a4f79bc 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ The test suite is composed of unit and integration tests in Rust, and end-to-end ```bash # Run unit/integration tests -make test +make unit-test # Run end-to-end tests make integration-test diff --git a/benches/store_bench.rs b/benches/store_bench.rs index 7f876b45..2a09b51c 100644 --- a/benches/store_bench.rs +++ b/benches/store_bench.rs @@ -59,7 +59,7 @@ async fn get_pending_activations(num_activations: u32, num_workers: u32) { let mut num_activations_processed = 0; while store - .get_pending_activation(Some("sentry"), Some(&ns)) + .get_pending_activation(Some("sentry"), Some(std::slice::from_ref(&ns))) .await .unwrap() .is_some() diff --git a/integration_tests/integration_tests/helpers.py b/integration_tests/integration_tests/helpers.py index 61389292..c860e6ee 100644 --- a/integration_tests/integration_tests/helpers.py +++ b/integration_tests/integration_tests/helpers.py @@ -18,6 +18,9 @@ TASKBROKER_ROOT = Path(__file__).parent.parent.parent TASKBROKER_BIN = TASKBROKER_ROOT / "target/debug/taskbroker" TESTS_OUTPUT_ROOT = Path(__file__).parent.parent / ".tests_output" + +TASKBROKER_RESTART_PORT_DELAY_SEC = 1.0 + TEST_PRODUCER_CONFIG = { "bootstrap.servers": "127.0.0.1:9092", "broker.address.family": "v4", @@ -212,7 +215,7 @@ def get_num_tasks_group_by_status( def get_available_ports(count: int) -> list[int]: - MIN = 49152 + MIN = 50051 MAX = 65535 res = [] for i in range(count): diff --git a/integration_tests/integration_tests/test_consumer_rebalancing.py b/integration_tests/integration_tests/test_consumer_rebalancing.py index f4380d1e..b9d21732 100644 --- a/integration_tests/integration_tests/test_consumer_rebalancing.py +++ 
b/integration_tests/integration_tests/test_consumer_rebalancing.py @@ -10,6 +10,7 @@ from integration_tests.helpers import ( TASKBROKER_BIN, + TASKBROKER_RESTART_PORT_DELAY_SEC, TESTS_OUTPUT_ROOT, TaskbrokerConfig, create_topic, @@ -48,6 +49,13 @@ def manage_taskbroker( assert return_code == 0 except Exception: process.kill() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + pass + + if i < iterations - 1: + time.sleep(TASKBROKER_RESTART_PORT_DELAY_SEC) def test_tasks_written_once_during_rebalancing() -> None: diff --git a/src/config.rs b/src/config.rs index 67b571d2..a547304e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -19,6 +19,17 @@ pub enum DatabaseAdapter { Postgres, } +/// How the taskbroker delivers tasks to workers. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum DeliveryMode { + /// Workers pull tasks from the broker. + Pull, + + /// Broker pushes tasks to workers. + Push, +} + #[derive(PartialEq, Debug, Deserialize, Serialize)] pub struct Config { /// The sentry DSN to use for error reporting. @@ -239,6 +250,42 @@ pub struct Config { /// Enable additional metrics for the sqlite. pub enable_sqlite_status_metrics: bool, + + /// How to deliver tasks to workers: "push" or "pull". + pub delivery_mode: DeliveryMode, + + /// The number of concurrent dispatchers to run. + pub fetch_threads: usize, + + /// Time in milliseconds to wait between fetch attempts when no pending activation is found. + pub fetch_wait_ms: u64, + + /// The number of concurrent pushers each dispatcher should run. + pub push_threads: usize, + + /// The size of the push queue. + pub push_queue_size: usize, + + /// Maximum time in milliseconds to wait when submitting an activation to the push pool. + pub push_queue_timeout_ms: u64, + + /// Maximum time in milliseconds for a single push RPC to the worker service. This should be greater than the worker's internal timeout. 
+ pub push_timeout_ms: u64, + + /// The worker service endpoint. + pub worker_endpoint: String, + + /// The hostname used to construct `callback_url` for task push requests. + pub callback_addr: String, + + /// The port used to construct `callback_url` for task push requests. + pub callback_port: u32, + + /// Application filter for push mode. When set, only pending activations for this application are considered. + pub application: Option<String>, + + /// List of namespaces for push mode. When set, application must also be set (store requirement). + pub namespaces: Option<Vec<String>>, } impl Default for Config { @@ -308,6 +355,18 @@ impl Default for Config { full_vacuum_on_upkeep: true, vacuum_interval_ms: 30000, enable_sqlite_status_metrics: true, + delivery_mode: DeliveryMode::Pull, + fetch_threads: 1, + fetch_wait_ms: 100, + push_threads: 1, + push_queue_size: 1, + push_queue_timeout_ms: 5000, + push_timeout_ms: 30000, + worker_endpoint: "http://127.0.0.1:50052".into(), + callback_addr: "0.0.0.0".into(), + callback_port: 50051, + application: None, + namespaces: None, } } } @@ -422,7 +481,7 @@ impl Provider for Config { mod tests { use std::{borrow::Cow, collections::BTreeMap}; - use super::{Config, DatabaseAdapter}; + use super::{Config, DatabaseAdapter, DeliveryMode}; use crate::{Args, logging::LogFormat}; use figment::Jail; @@ -712,4 +771,127 @@ mod tests { Ok(()) }); } + + #[test] + fn test_default_delivery_mode() { + let config = Config::default(); + assert_eq!(config.delivery_mode, DeliveryMode::Pull); + } + + #[test] + fn test_from_args_delivery_mode_from_env() { + Jail::expect_with(|jail| { + jail.set_env("TASKBROKER_DELIVERY_MODE", "push"); + + let args = Args { config: None }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.delivery_mode, DeliveryMode::Push); + + Ok(()) + }); + } + + #[test] + fn test_from_args_delivery_mode_from_config_file() { + Jail::expect_with(|jail| { + jail.create_file("config.yaml", "delivery_mode: push")?; + + let args =
Args { + config: Some("config.yaml".to_owned()), + }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.delivery_mode, DeliveryMode::Push); + + Ok(()) + }); + } + + #[test] + fn test_default_push_callback_fields() { + let config = Config::default(); + assert_eq!(config.callback_addr, "0.0.0.0"); + assert_eq!(config.callback_port, 50051); + } + + #[test] + fn test_from_args_push_callback_fields_from_env() { + Jail::expect_with(|jail| { + jail.set_env("TASKBROKER_CALLBACK_ADDR", "127.0.0.1"); + jail.set_env("TASKBROKER_CALLBACK_PORT", "51000"); + + let args = Args { config: None }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.callback_addr, "127.0.0.1"); + assert_eq!(config.callback_port, 51000); + + Ok(()) + }); + } + + #[test] + fn test_from_args_push_callback_fields_from_config_file() { + Jail::expect_with(|jail| { + jail.create_file( + "config.yaml", + r#" + callback_addr: 10.0.0.1 + callback_port: 52000 + "#, + )?; + + let args = Args { + config: Some("config.yaml".to_owned()), + }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.callback_addr, "10.0.0.1"); + assert_eq!(config.callback_port, 52000); + + Ok(()) + }); + } + + #[test] + fn test_default_application_and_namespaces() { + let config = Config::default(); + assert_eq!(config.application, None); + assert_eq!(config.namespaces, None); + } + + #[test] + fn test_from_args_application_from_env() { + Jail::expect_with(|jail| { + jail.set_env("TASKBROKER_APPLICATION", "getsentry"); + + let args = Args { config: None }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.application.as_deref(), Some("getsentry")); + assert_eq!(config.namespaces, None); + + Ok(()) + }); + } + + #[test] + fn test_from_args_application_and_namespaces_from_config_file() { + Jail::expect_with(|jail| { + jail.create_file( + "config.yaml", + r#" + application: getsentry + namespaces: + - ns1 + - ns2 + "#, + )?; + + let args = Args { + config: 
Some("config.yaml".to_owned()), + }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.application.as_deref(), Some("getsentry")); + assert_eq!(config.namespaces, Some(vec!["ns1".into(), "ns2".into()])); + + Ok(()) + }); + } } diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs new file mode 100644 index 00000000..5fb81715 --- /dev/null +++ b/src/fetch/mod.rs @@ -0,0 +1,151 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use tokio::time::sleep; +use tonic::async_trait; +use tracing::{debug, info, warn}; + +use crate::config::Config; +use crate::push::{PushError, PushPool}; +use crate::store::inflight_activation::{InflightActivation, InflightActivationStore}; + +/// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. +#[async_trait] +pub trait TaskPusher { + /// Push a single task to the worker service. + async fn push_task(&self, activation: InflightActivation) -> Result<(), PushError>; +} + +#[async_trait] +impl TaskPusher for PushPool { + async fn push_task(&self, activation: InflightActivation) -> Result<(), PushError> { + self.submit(activation).await + } +} + +/// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches a pending activation from the store, passes it to the push pool, and repeats. +pub struct FetchPool { + /// Inflight activation store. + store: Arc<dyn InflightActivationStore>, + + /// Pool of push threads that push activations to the worker service. + pusher: Arc<dyn TaskPusher>, + + /// Taskbroker configuration. + config: Arc<Config>, +} + +impl FetchPool { + /// Initialize a new fetch pool. + pub fn new( + store: Arc<dyn InflightActivationStore>, + config: Arc<Config>, + pusher: Arc<dyn TaskPusher>, + ) -> Self { + Self { + store, + config, + pusher, + } + } + + /// Spawn `config.fetch_threads` asynchronous tasks, each of which repeatedly moves pending activations from the store to the push pool until the shutdown signal is received.
+ pub async fn start(&self) -> Result<()> { + let fetch_wait_ms = self.config.fetch_wait_ms; + + let mut fetch_pool = crate::tokio::spawn_pool(self.config.fetch_threads, |_| { + let store = self.store.clone(); + let pusher = self.pusher.clone(); + let config = self.config.clone(); + + let guard = get_shutdown_guard().shutdown_on_drop(); + + async move { + loop { + tokio::select! { + _ = guard.wait() => { + info!("Fetch loop received shutdown signal"); + return; + } + + _ = async { + debug!("Fetching next pending activation..."); + metrics::counter!("fetch.loop.count").increment(1); + + let start = Instant::now(); + let mut backoff = false; + + let application = config.application.as_deref(); + let namespaces = config.namespaces.as_deref(); + + match store.get_pending_activation(application, namespaces).await { + Ok(Some(activation)) => { + let id = activation.id.clone(); + + debug!( + task_id = %id, + "Fetched and marked task as processing" + ); + + if let Err(e) = pusher.push_task(activation).await { + match e { + PushError::Timeout => warn!( + task_id = %id, + "Submit to push pool timed out after {} milliseconds", config.push_queue_timeout_ms, + ), + + PushError::Channel(e) => warn!( + task_id = %id, + error = ?e, + "Submit to push pool failed due to channel error", + ) + } + + backoff = true; + } + } + + Ok(None) => { + debug!("No pending activations"); + + // Wait for pending activations to appear + backoff = true; + } + + Err(e) => { + warn!( + error = ?e, + "Store failed while fetching task" + ); + + // Store may be down, wait before trying again + backoff = true; + } + }; + + metrics::histogram!("fetch.loop.duration") + .record(start.elapsed()); + + if backoff { + sleep(Duration::from_millis(fetch_wait_ms)).await; + } + } => {} + } + } + } + }); + + while let Some(res) = fetch_pool.join_next().await { + if let Err(e) = res { + return Err(e.into()); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests; diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs new file mode 100644 index
00000000..a77a9aac --- /dev/null +++ b/src/fetch/tests.rs @@ -0,0 +1,260 @@ +use std::sync::Arc; + +use anyhow::{Error, anyhow}; +use chrono::{DateTime, Utc}; +use tokio::sync::Mutex; +use tokio::time::{Duration, sleep}; +use tonic::async_trait; + +use super::*; +use crate::config::Config; +use crate::push::PushError; +use crate::store::inflight_activation::InflightActivation; +use crate::store::inflight_activation::InflightActivationStore; +use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivationStatus, QueryResult, +}; +use crate::test_utils::make_activations; + +/// Store stub that returns one activation once OR is always empty OR always fails. +struct MockStore { + /// A single (optional) pending activation. + pending: Mutex<Option<InflightActivation>>, + + /// Should operations fail? + fail: bool, +} + +impl MockStore { + fn empty() -> Self { + Self { + pending: Mutex::new(None), + fail: false, + } + } + + fn one(activation: InflightActivation) -> Self { + Self { + pending: Mutex::new(Some(activation)), + fail: false, + } + } + + fn error() -> Self { + Self { + pending: Mutex::new(None), + fail: true, + } + } +} + +#[async_trait] +impl InflightActivationStore for MockStore { + async fn vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn full_vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn db_size(&self) -> Result<u64, Error> { + unimplemented!() + } + + async fn get_by_id(&self, _id: &str) -> Result<Option<InflightActivation>, Error> { + unimplemented!() + } + + async fn store(&self, _batch: Vec<InflightActivation>) -> Result<QueryResult, Error> { + unimplemented!() + } + + async fn get_pending_activation( + &self, + _application: Option<&str>, + _namespaces: Option<&[String]>, + ) -> Result<Option<InflightActivation>, Error> { + if self.fail { + return Err(anyhow!("mock store error")); + } + + Ok(self.pending.lock().await.take()) + } + + async fn get_pending_activations( + &self, + _application: Option<&str>, + _namespaces: Option<&[String]>, + _limit: Option<u32>, + ) -> Result<Vec<InflightActivation>, Error> { + unimplemented!() + } + +
async fn pending_activation_max_lag(&self, _now: &DateTime<Utc>) -> f64 { + unimplemented!() + } + + async fn count_by_status(&self, _status: InflightActivationStatus) -> Result<usize, Error> { + unimplemented!() + } + + async fn count(&self) -> Result<usize, Error> { + unimplemented!() + } + + async fn set_status( + &self, + _id: &str, + _status: InflightActivationStatus, + ) -> Result<Option<InflightActivation>, Error> { + unimplemented!() + } + + async fn set_processing_deadline( + &self, + _id: &str, + _deadline: Option<DateTime<Utc>>, + ) -> Result<(), Error> { + unimplemented!() + } + + async fn delete_activation(&self, _id: &str) -> Result<(), Error> { + unimplemented!() + } + + async fn get_retry_activations(&self) -> Result<Vec<InflightActivation>, Error> { + unimplemented!() + } + + async fn clear(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn handle_processing_deadline(&self) -> Result<u64, Error> { + unimplemented!() + } + + async fn handle_processing_attempts(&self) -> Result<u64, Error> { + unimplemented!() + } + + async fn handle_expires_at(&self) -> Result<u64, Error> { + unimplemented!() + } + + async fn handle_delay_until(&self) -> Result<u64, Error> { + unimplemented!() + } + + async fn handle_failed_tasks(&self) -> Result<FailedTasksForwarder, Error> { + unimplemented!() + } + + async fn mark_completed(&self, _ids: Vec<String>) -> Result<u64, Error> { + unimplemented!() + } + + async fn remove_completed(&self) -> Result<u64, Error> { + unimplemented!() + } + + async fn remove_killswitched(&self, _killswitched_tasks: Vec<String>) -> Result<u64, Error> { + unimplemented!() + } +} + +/// Records task IDs passed to `push_task`. If `fail` is true, returns an error. +struct RecordingPusher { + /// What IDs have been pushed? + pushed_ids: Mutex<Vec<String>>, + + /// Should pushing fail?
+ fail: bool, +} + +impl RecordingPusher { + fn new(fail: bool) -> Self { + let pushed_ids = Mutex::new(vec![]); + Self { pushed_ids, fail } + } +} + +#[async_trait] +impl TaskPusher for RecordingPusher { + async fn push_task(&self, activation: InflightActivation) -> Result<(), PushError> { + self.pushed_ids.lock().await.push(activation.id.clone()); + + if self.fail { + return Err(PushError::Timeout); + } + + Ok(()) + } +} + +fn test_config() -> Arc<Config> { + Arc::new(Config { + fetch_threads: 1, + fetch_wait_ms: 5, + ..Config::default() + }) +} + +#[tokio::test] +async fn fetch_pool_delivers_activation_to_pusher() { + let activation = make_activations(1).remove(0); + let store: Arc<dyn InflightActivationStore> = Arc::new(MockStore::one(activation.clone())); + let pusher = Arc::new(RecordingPusher::new(false)); + + let pool = FetchPool::new(store, test_config(), pusher.clone()); + let handle = tokio::spawn(async move { pool.start().await }); + + sleep(Duration::from_millis(200)).await; + assert_eq!(pusher.pushed_ids.lock().await.as_slice(), &[activation.id]); + + handle.abort(); +} + +#[tokio::test] +async fn fetch_pool_calls_pusher_once_when_push_errors() { + let activation = make_activations(1).remove(0); + let store: Arc<dyn InflightActivationStore> = Arc::new(MockStore::one(activation)); + let pusher = Arc::new(RecordingPusher::new(true)); + + let pool = FetchPool::new(store, test_config(), pusher.clone()); + let handle = tokio::spawn(async move { pool.start().await }); + + sleep(Duration::from_millis(100)).await; + assert_eq!(pusher.pushed_ids.lock().await.len(), 1); + + handle.abort(); +} + +#[tokio::test] +async fn fetch_pool_skips_pusher_when_store_errors() { + let store: Arc<dyn InflightActivationStore> = Arc::new(MockStore::error()); + let pusher = Arc::new(RecordingPusher::new(false)); + + let pool = FetchPool::new(store, test_config(), pusher.clone()); + let handle = tokio::spawn(async move { pool.start().await }); + + sleep(Duration::from_millis(80)).await; + assert!(pusher.pushed_ids.lock().await.is_empty()); + + handle.abort(); +}
+#[tokio::test] +async fn fetch_pool_skips_pusher_when_no_pending() { + let store: Arc<dyn InflightActivationStore> = Arc::new(MockStore::empty()); + let pusher = Arc::new(RecordingPusher::new(false)); + + let pool = FetchPool::new(store, test_config(), pusher.clone()); + let handle = tokio::spawn(async move { pool.start().await }); + + sleep(Duration::from_millis(80)).await; + assert!(pusher.pushed_ids.lock().await.is_empty()); + + handle.abort(); +} diff --git a/src/grpc/server.rs b/src/grpc/server.rs index f5ac9292..03db46ba 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -9,11 +9,13 @@ use std::sync::Arc; use std::time::Instant; use tonic::{Request, Response, Status}; +use crate::config::{Config, DeliveryMode}; use crate::store::inflight_activation::{InflightActivationStatus, InflightActivationStore}; use tracing::{error, instrument}; pub struct TaskbrokerServer { pub store: Arc<dyn InflightActivationStore>, + pub config: Arc<Config>, } #[tonic::async_trait] @@ -23,12 +25,21 @@ impl ConsumerService for TaskbrokerServer { &self, request: Request<GetTaskRequest>, ) -> Result<Response<GetTaskResponse>, Status> { + if self.config.delivery_mode == DeliveryMode::Push { + return Err(Status::permission_denied( + "Cannot call while broker is in PUSH mode", + )); + } + let start_time = Instant::now(); + let application = &request.get_ref().application; let namespace = &request.get_ref().namespace; + let namespaces = namespace.as_ref().map(std::slice::from_ref); + let inflight = self .store - .get_pending_activation(application.as_deref(), namespace.as_deref()) + .get_pending_activation(application.as_deref(), namespaces) .await; match inflight { @@ -98,6 +109,10 @@ impl ConsumerService for TaskbrokerServer { } metrics::histogram!("grpc_server.set_status.duration").record(start_time.elapsed()); + if self.config.delivery_mode == DeliveryMode::Push { + return Ok(Response::new(SetTaskStatusResponse { task: None })); + } + let Some(FetchNextTask { ref namespace, ref application, @@ -107,9 +122,10 @@ impl ConsumerService for TaskbrokerServer { }; let start_time =
Instant::now(); + let namespaces = namespace.as_ref().map(std::slice::from_ref); let res = match self .store - .get_pending_activation(application.as_deref(), namespace.as_deref()) + .get_pending_activation(application.as_deref(), namespaces) .await { Err(e) => { diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index b99911a2..f0834a95 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -1,3 +1,6 @@ +use std::sync::Arc; + +use crate::config::{Config, DeliveryMode}; use crate::grpc::server::TaskbrokerServer; use prost::Message; use rstest::rstest; @@ -7,7 +10,28 @@ use sentry_protos::taskbroker::v1::{ }; use tonic::{Code, Request}; -use crate::test_utils::{create_test_store, make_activations}; +use crate::test_utils::{create_config, create_test_store, make_activations}; + +#[tokio::test] +async fn test_get_task_push_mode_returns_permission_denied() { + let store = create_test_store("sqlite").await; + let config = Arc::new(Config { + delivery_mode: DeliveryMode::Push, + ..Config::default() + }); + + let service = TaskbrokerServer { store, config }; + let request = GetTaskRequest { + namespace: None, + application: None, + }; + let response = service.get_task(Request::new(request)).await; + + assert!(response.is_err()); + let e = response.unwrap_err(); + assert_eq!(e.code(), Code::PermissionDenied); + assert_eq!(e.message(), "Cannot call while broker is in PUSH mode"); +} #[tokio::test] #[rstest] @@ -15,7 +39,9 @@ use crate::test_utils::{create_test_store, make_activations}; #[case::postgres("postgres")] async fn test_get_task(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: None, @@ -34,7 +60,9 @@ async fn test_get_task(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status(#[case] adapter: &str) { let 
store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 5, // Complete @@ -53,7 +81,9 @@ async fn test_set_task_status(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_invalid(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 1, // Invalid @@ -76,10 +106,12 @@ async fn test_set_task_status_invalid(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(1); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: None, @@ -99,6 +131,8 @@ async fn test_get_task_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_with_application_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -108,7 +142,7 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: Some("hammers".into()), @@ -129,12 +163,14 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { #[allow(deprecated)] async fn 
test_get_task_with_namespace_requires_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); let namespace = activations[0].namespace.clone(); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: Some(namespace), application: None, @@ -153,10 +189,12 @@ async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str #[allow(deprecated)] async fn test_set_task_status_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, @@ -192,6 +230,8 @@ async fn test_set_task_status_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -201,7 +241,7 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete @@ -229,6 +269,8 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = 
TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -238,7 +280,7 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; // Request a task from an application without any activations. let request = SetTaskStatusRequest { id: "id_0".to_string(), @@ -261,12 +303,14 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_namespace_requires_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); let namespace = activations[0].namespace.clone(); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete diff --git a/src/lib.rs b/src/lib.rs index 33567944..6ce53cd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,13 +2,16 @@ use clap::Parser; use std::fs; pub mod config; +pub mod fetch; pub mod grpc; pub mod kafka; pub mod logging; pub mod metrics; +pub mod push; pub mod runtime_config; pub mod store; pub mod test_utils; +pub mod tokio; pub mod upkeep; /// Name of the grpc service. 
diff --git a/src/main.rs b/src/main.rs index 7970939d..db0ebd84 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,9 +2,11 @@ use anyhow::{Error, anyhow}; use chrono::Utc; use clap::Parser; use std::{sync::Arc, time::Duration}; +use taskbroker::fetch::FetchPool; use taskbroker::kafka::inflight_activation_batcher::{ ActivationBatcherConfig, InflightActivationBatcher, }; +use taskbroker::push::PushPool; use taskbroker::upkeep::upkeep; use tokio::signal::unix::SignalKind; use tokio::task::JoinHandle; @@ -15,7 +17,7 @@ use tracing::{debug, error, info, warn}; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer; use taskbroker::SERVICE_NAME; -use taskbroker::config::{Config, DatabaseAdapter}; +use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; use taskbroker::grpc::server::TaskbrokerServer; @@ -39,16 +41,16 @@ use taskbroker::store::postgres_activation_store::{ use taskbroker::{Args, get_version}; use tonic_health::ServingStatus; -async fn log_task_completion(name: &str, task: JoinHandle>) { +async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { Ok(Ok(())) => { - info!("Task {} completed", name); + info!("Task {} completed", name.as_ref()); } Ok(Err(e)) => { - error!("Task {} failed: {:?}", name, e); + error!("Task {} failed: {:?}", name.as_ref(), e); } Err(e) => { - error!("Task {} panicked: {:?}", name, e); + error!("Task {} panicked: {:?}", name.as_ref(), e); } } } @@ -192,6 +194,7 @@ async fn main() -> Result<(), Error> { let grpc_server_task = tokio::spawn({ let grpc_store = store.clone(); let grpc_config = config.clone(); + async move { let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) .parse() @@ -206,6 +209,7 @@ async fn main() -> Result<(), Error> { .layer(layers) .add_service(ConsumerServiceServer::new(TaskbrokerServer { store: grpc_store, + config: grpc_config, })) 
.add_service(health_service.clone()) .serve(addr); @@ -236,7 +240,25 @@ async fn main() -> Result<(), Error> { } }); - elegant_departure::tokio::depart() + // Initialize push and fetch pools + let push_pool = Arc::new(PushPool::new(config.clone())); + let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); + + // Initialize push threads + let push_task = if config.delivery_mode == DeliveryMode::Push { + Some(tokio::spawn(async move { push_pool.start().await })) + } else { + None + }; + + // Initialize fetch threads + let fetch_task = if config.delivery_mode == DeliveryMode::Push { + Some(tokio::spawn(async move { fetch_pool.start().await })) + } else { + None + }; + + let mut departure = elegant_departure::tokio::depart() .on_termination() .on_sigint() .on_signal(SignalKind::hangup()) @@ -244,8 +266,16 @@ async fn main() -> Result<(), Error> { .on_completion(log_task_completion("consumer", consumer_task)) .on_completion(log_task_completion("grpc_server", grpc_server_task)) .on_completion(log_task_completion("upkeep_task", upkeep_task)) - .on_completion(log_task_completion("maintenance_task", maintenance_task)) - .await; + .on_completion(log_task_completion("maintenance_task", maintenance_task)); + + if let Some(task) = push_task { + departure = departure.on_completion(log_task_completion("push_task", task)); + } + + if let Some(task) = fetch_task { + departure = departure.on_completion(log_task_completion("fetch_task", task)); + } + departure.await; Ok(()) } diff --git a/src/push/mod.rs b/src/push/mod.rs new file mode 100644 index 00000000..60408130 --- /dev/null +++ b/src/push/mod.rs @@ -0,0 +1,209 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use flume::{Receiver, SendError, Sender}; +use prost::Message; +use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; +use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; 
+use tokio::task::JoinSet; +use tonic::async_trait; +use tonic::transport::Channel; +use tracing::{debug, error, info}; + +use crate::config::Config; +use crate::store::inflight_activation::InflightActivation; + +/// Error returned when enqueueing an activation for the push workers fails. +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum PushError { + /// The bounded queue stayed full until `push_queue_timeout_ms` elapsed. + Timeout, + + /// Channel disconnected (no receivers) or another failure. + Channel(SendError), +} + +/// Thin interface for the worker client. It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. +#[async_trait] +trait WorkerClient { + /// Send a single `PushTaskRequest` to the worker service. + async fn send(&mut self, request: PushTaskRequest) -> Result<()>; +} + +#[async_trait] +impl WorkerClient for WorkerServiceClient { + async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + self.push_task(request).await?; + Ok(()) + } +} + +/// Wrapper around `config.push_threads` asynchronous tasks, each of which receives an activation from the channel, sends it to the worker service, and repeats. +pub struct PushPool { + /// The sending end of a channel that accepts task activations. + sender: Sender, + + /// The receiving end of a channel that accepts task activations. + receiver: Receiver, + + /// Taskbroker configuration. + config: Arc, +} + +impl PushPool { + /// Initialize a new push pool. + pub fn new(config: Arc) -> Self { + let (sender, receiver) = flume::bounded(config.push_queue_size); + + Self { + sender, + receiver, + config, + } + } + + /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. 
+ pub async fn start(&self) -> Result<()> { + let mut push_pool: JoinSet> = crate::tokio::spawn_pool( + self.config.push_threads, + |_| { + let endpoint = self.config.worker_endpoint.clone(); + let receiver = self.receiver.clone(); + + let guard = get_shutdown_guard().shutdown_on_drop(); + + let callback_url = format!( + "{}:{}", + self.config.callback_addr, self.config.callback_port + ); + + let timeout = Duration::from_millis(self.config.push_timeout_ms); + + async move { + let mut worker = match WorkerServiceClient::connect(endpoint).await { + Ok(w) => w, + Err(e) => { + error!("Failed to connect to worker - {:?}", e); + + // When we fail to connect, the taskbroker will shut down, but this may change in the future + return Err(e.into()); + } + }; + + loop { + tokio::select! { + _ = guard.wait() => { + info!("Push worker received shutdown signal"); + break; + } + + message = receiver.recv_async() => { + let activation = match message { + // Received activation from fetch thread + Ok(a) => a, + + // Channel closed + Err(_) => break + }; + + let id = activation.id.clone(); + let callback_url = callback_url.clone(); + + match push_task(&mut worker, activation, callback_url, timeout).await { + Ok(_) => debug!(task_id = %id, "Activation sent to worker"), + + // Once processing deadline expires, status will be set back to pending + Err(e) => error!( + task_id = %id, + error = ?e, + "Failed to send activation to worker" + ) + }; + } + } + } + + // Drain channel before exiting + for activation in receiver.drain() { + let id = activation.id.clone(); + let callback_url = callback_url.clone(); + + match push_task(&mut worker, activation, callback_url, timeout).await { + Ok(_) => debug!(task_id = %id, "Activation sent to worker"), + + // Once processing deadline expires, status will be set back to pending + Err(e) => error!( + task_id = %id, + error = ?e, + "Failed to send activation to worker" + ), + }; + } + + Ok(()) + } + }, + ); + + while let Some(result) = 
push_pool.join_next().await { + match result { + Ok(r) => { + // Connection failed + r? + } + + // Join failed + Err(e) => return Err(e.into()), + } + } + + Ok(()) + } + + /// Send an activation to the internal asynchronous MPMC channel used by all running push threads. Times out after `config.push_queue_timeout_ms` milliseconds. + pub async fn submit(&self, activation: InflightActivation) -> Result<(), PushError> { + let duration = Duration::from_millis(self.config.push_queue_timeout_ms); + + match tokio::time::timeout(duration, self.sender.send_async(activation)).await { + Ok(Ok(())) => Ok(()), + + // The channel has a problem + Ok(Err(e)) => Err(PushError::Channel(e)), + + // The channel was full so the send timed out + Err(_elapsed) => Err(PushError::Timeout), + } + } +} + +/// Decode task activation and push it to a worker. +async fn push_task( + worker: &mut W, + activation: InflightActivation, + callback_url: String, + timeout: Duration, +) -> Result<()> { + let start = Instant::now(); + + // Try to decode activation (if it fails, we will see the error where `push_task` is called) + let task = TaskActivation::decode(&activation.activation as &[u8])?; + + let request = PushTaskRequest { + task: Some(task), + callback_url, + }; + + let result = match tokio::time::timeout(timeout, worker.send(request)).await { + Ok(r) => r, + Err(e) => Err(e.into()), + }; + + metrics::histogram!("push.push_task.duration").record(start.elapsed()); + result +} + +#[cfg(test)] +mod tests; diff --git a/src/push/tests.rs b/src/push/tests.rs new file mode 100644 index 00000000..b2db0ba1 --- /dev/null +++ b/src/push/tests.rs @@ -0,0 +1,141 @@ +use std::sync::Arc; + +use anyhow::anyhow; +use sentry_protos::taskbroker::v1::PushTaskRequest; +use tokio::time::{Duration, timeout}; +use tonic::async_trait; + +use super::*; +use crate::config::Config; +use crate::test_utils::make_activations; + +/// Fake worker client for unit testing. 
+struct MockWorkerClient { + /// Capture all received requests so we can assert things about them. + captured_requests: Vec, + + /// Should requests to the worker client fail? + should_fail: bool, +} + +impl MockWorkerClient { + fn new(should_fail: bool) -> Self { + let captured_requests = vec![]; + + Self { + captured_requests, + should_fail, + } + } +} + +#[async_trait] +impl WorkerClient for MockWorkerClient { + async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + self.captured_requests.push(request); + + if self.should_fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } +} + +#[tokio::test] +async fn push_task_returns_ok_on_client_success() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(false); + let callback_url = "taskbroker:50051".to_string(); + + let result = push_task( + &mut worker, + activation.clone(), + callback_url.clone(), + Duration::from_secs(5), + ) + .await; + assert!(result.is_ok(), "push_task should succeed"); + assert_eq!(worker.captured_requests.len(), 1); + + let request = &worker.captured_requests[0]; + assert_eq!(request.callback_url, callback_url); + assert_eq!( + request.task.as_ref().map(|task| task.id.as_str()), + Some(activation.id.as_str()) + ); +} + +#[tokio::test] +async fn push_task_returns_err_on_invalid_payload() { + let mut activation = make_activations(1).remove(0); + activation.activation = vec![1, 2, 3, 4]; + + let mut worker = MockWorkerClient::new(false); + let result = push_task( + &mut worker, + activation, + "taskbroker:50051".to_string(), + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err(), "invalid payload should fail decoding"); + assert!( + worker.captured_requests.is_empty(), + "worker should not be called if decode fails" + ); +} + +#[tokio::test] +async fn push_task_propagates_client_error() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(true); + + let result = push_task( + 
&mut worker, + activation, + "taskbroker:50051".to_string(), + Duration::from_secs(5), + ) + .await; + assert!(result.is_err(), "worker send errors should propagate"); + assert_eq!(worker.captured_requests.len(), 1); +} + +#[tokio::test] +async fn push_pool_submit_enqueues_item() { + let config = Arc::new(Config { + push_queue_size: 2, + ..Config::default() + }); + + let pool = PushPool::new(config); + let activation = make_activations(1).remove(0); + + let result = pool.submit(activation).await; + assert!(result.is_ok(), "submit should enqueue activation"); +} + +#[tokio::test] +async fn push_pool_submit_backpressures_when_queue_full() { + let config = Arc::new(Config { + push_queue_size: 1, + ..Config::default() + }); + + let pool = PushPool::new(config); + + let first = make_activations(1).remove(0); + let second = make_activations(1).remove(0); + + pool.submit(first) + .await + .expect("first submit should fill queue"); + + let second_submit = timeout(Duration::from_millis(50), pool.submit(second)).await; + assert!( + second_submit.is_err(), + "second submit should block when queue is full" + ); +} diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs index 9de6bf2b..4cd29b12 100644 --- a/src/store/inflight_activation.rs +++ b/src/store/inflight_activation.rs @@ -352,33 +352,33 @@ pub trait InflightActivationStore: Send + Sync { /// Store a batch of activations async fn store(&self, batch: Vec) -> Result; - /// Get a single pending activation, optionally filtered by namespace + /// Get a single pending activation, optionally filtered by namespaces. async fn get_pending_activation( &self, application: Option<&str>, - namespace: Option<&str>, + namespaces: Option<&[String]>, ) -> Result, Error> { - // Convert single namespace to vector for internal use - let namespaces = namespace.map(|ns| vec![ns.to_string()]); - - // If a namespace filter is used, an application must also be used. 
if namespaces.is_some() && application.is_none() { warn!( - "Received request for namespaced task without application. namespaces = {namespaces:?}" + ?namespaces, + "Received request for namespaced task without application" ); return Ok(None); } - let result = self - .get_pending_activations_from_namespaces(application, namespaces.as_deref(), Some(1)) + + let results = self + .get_pending_activations(application, namespaces, Some(1)) .await?; - if result.is_empty() { + + if results.is_empty() { return Ok(None); } - Ok(Some(result[0].clone())) + + Ok(Some(results[0].clone())) } - /// Get pending activations from specified namespaces - async fn get_pending_activations_from_namespaces( + /// Claim pending activations (moves them to processing), optionally filtered by application and namespaces. + async fn get_pending_activations( &self, application: Option<&str>, namespaces: Option<&[String]>, @@ -800,11 +800,11 @@ impl InflightActivationStore for SqliteActivationStore { meta_result } - /// Get a pending activation from specified namespaces - /// If namespaces is None, gets from any namespace - /// If namespaces is Some(&[...]), gets from those namespaces + /// Claim pending activations from specified namespaces (moves them to processing). + /// If namespaces is `None`, gets from any namespace. + /// If namespaces is `Some(...)`, restricts to those namespaces. 
#[instrument(skip_all)] - async fn get_pending_activations_from_namespaces( + async fn get_pending_activations( &self, application: Option<&str>, namespaces: Option<&[String]>, diff --git a/src/store/inflight_activation_tests.rs b/src/store/inflight_activation_tests.rs index 9113b26c..010f35ad 100644 --- a/src/store/inflight_activation_tests.rs +++ b/src/store/inflight_activation_tests.rs @@ -232,7 +232,7 @@ async fn test_get_pending_activation_with_race(#[case] adapter: &str) { join_set.spawn(async move { rx.recv().await.unwrap(); store - .get_pending_activation(Some("sentry"), Some(&ns)) + .get_pending_activation(Some("sentry"), Some(std::slice::from_ref(&ns))) .await .unwrap() .unwrap() @@ -263,9 +263,10 @@ async fn test_get_pending_activation_with_namespace(#[case] adapter: &str) { batch[1].namespace = "other_namespace".into(); assert!(store.store(batch.clone()).await.is_ok()); + let other_namespace = "other_namespace".to_string(); // Get activation from other namespace let result = store - .get_pending_activation(Some("sentry"), Some("other_namespace")) + .get_pending_activation(Some("sentry"), Some(std::slice::from_ref(&other_namespace))) .await .unwrap() .unwrap(); @@ -293,7 +294,7 @@ async fn test_get_pending_activation_from_multiple_namespaces(#[case] adapter: & // Get activation from multiple namespaces (should get oldest) let namespaces = vec!["ns2".to_string(), "ns3".to_string()]; let result = store - .get_pending_activations_from_namespaces(None, Some(&namespaces), None) + .get_pending_activations(None, Some(&namespaces), None) .await .unwrap(); @@ -320,8 +321,9 @@ async fn test_get_pending_activation_with_namespace_requires_application(#[case] // This is an invalid query as we don't want to allow clients // to fetch tasks from any application. 
+ let other_namespace = "other_namespace".to_string(); let opt = store - .get_pending_activation(None, Some("other_namespace")) + .get_pending_activation(None, Some(std::slice::from_ref(&other_namespace))) .await .unwrap(); assert!(opt.is_none()); @@ -329,7 +331,7 @@ async fn test_get_pending_activation_with_namespace_requires_application(#[case] // We allow no application in this method because of usage in upkeep let namespaces = vec!["other_namespace".to_string()]; let activations = store - .get_pending_activations_from_namespaces(None, Some(namespaces).as_deref(), Some(2)) + .get_pending_activations(None, Some(&namespaces), Some(2)) .await .unwrap(); assert_eq!( @@ -479,9 +481,10 @@ async fn test_get_pending_activation_with_application_and_namespace(#[case] adap batch[2].namespace = "not-target".into(); assert!(store.store(batch.clone()).await.is_ok()); + let target_ns = "target".to_string(); // Get activation from a named application let result = store - .get_pending_activation(Some("hammers"), Some("target")) + .get_pending_activation(Some("hammers"), Some(std::slice::from_ref(&target_ns))) .await .unwrap() .unwrap(); diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs index bd01acdc..8d472713 100644 --- a/src/store/postgres_activation_store.rs +++ b/src/store/postgres_activation_store.rs @@ -255,11 +255,11 @@ impl InflightActivationStore for PostgresActivationStore { Ok(query.execute(&mut *conn).await?.into()) } - /// Get a pending activation from specified namespaces - /// If namespaces is None, gets from any namespace - /// If namespaces is Some(&[...]), gets from those namespaces + /// Claim pending activations from specified namespaces (moves them to processing). + /// If namespaces is `None`, gets from any namespace. + /// If namespaces is `Some(...)`, restricts to those namespaces. 
#[instrument(skip_all)] - async fn get_pending_activations_from_namespaces( + async fn get_pending_activations( &self, application: Option<&str>, namespaces: Option<&[String]>, diff --git a/src/tokio.rs b/src/tokio.rs new file mode 100644 index 00000000..5452e285 --- /dev/null +++ b/src/tokio.rs @@ -0,0 +1,33 @@ +use std::future::Future; + +use tokio::task::JoinSet; + +/// Spawns `max(n, 1)` tasks, each running the future produced by `f` with the task's index. +/// Returns a [`JoinSet`] containing all spawned tasks. +pub fn spawn_pool(n: usize, f: F) -> JoinSet +where + F: Fn(usize) -> Fut, + Fut: Future + Send + 'static, + Fut::Output: Send, +{ + let mut join_set = JoinSet::new(); + + let count = n.max(1); + for i in 0..count { + join_set.spawn(f(i)); + } + + join_set +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn spawn_pool_spawns_one_worker_when_n_is_zero() { + let mut set = spawn_pool(0, |i| async move { i }); + assert_eq!(set.join_next().await.unwrap().unwrap(), 0); + assert!(set.join_next().await.is_none()); + } +} diff --git a/src/upkeep.rs b/src/upkeep.rs index 58eaf986..3c6bfeb4 100644 --- a/src/upkeep.rs +++ b/src/upkeep.rs @@ -304,7 +304,7 @@ pub async fn do_upkeep( .expect("Could not create kafka producer in upkeep"), ); if let Ok(tasks) = store - .get_pending_activations_from_namespaces(None, Some(&demoted_namespaces), None) + .get_pending_activations(None, Some(&demoted_namespaces), None) .await { // Produce tasks to Kafka with updated namespace