Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/core/src/client/message_handlers_v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst
res.map_err(|e| {
(
Some(reducer),
mod_info.module_def.reducer_full(&**reducer).map(|(id, _)| id),
mod_info.module_def.reducer_by_name(&**reducer).map(|(id, _)| id),
e.into(),
)
})
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/host/module_host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2163,7 +2163,7 @@ impl ModuleHost {
let (reducer_id, reducer_def) = self
.info
.module_def
.reducer_full(reducer_name)
.reducer_by_name(reducer_name)
.ok_or(ReducerCallError::NoSuchReducer)?;
if let Some(lifecycle) = reducer_def.lifecycle {
return Err(ReducerCallError::LifecycleReducer(lifecycle));
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/host/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ fn function_to_call_params(

// Find the function and deserialize the arguments.
let module = &module.module_def;
let (id, args) = if let Some((id, def)) = module.reducer_full(name) {
let (id, args) = if let Some((id, def)) = module.reducer_by_name(name) {
let args = args.into_tuple_for_def(module, def).map_err(InvalidReducerArguments)?;
(FunctionId::Reducer(id), args)
} else if let Some((id, def)) = module.procedure_full(name) {
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/host/wasm_common/module_host_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1453,7 +1453,7 @@ impl AllVmMetrics {
let def = &info.module_def;
let reducers = def.reducer_ids_and_defs();
let num_reducers = reducers.len() as u32;
let reducers = reducers.map(|(_, def)| def.name());
let reducers = reducers.into_iter().map(|(_, def)| def.name());

// These are the views:
let views = def.views().map(|def| def.name());
Expand Down
144 changes: 137 additions & 7 deletions crates/schema/src/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,30 @@ impl ModuleDef {
self.reducers.values()
}

/// Returns an iterator over all reducer ids and definitions.
pub fn reducer_ids_and_defs(&self) -> impl ExactSizeIterator<Item = (ReducerId, &ReducerDef)> {
self.reducers.values().enumerate().map(|(idx, def)| (idx.into(), def))
/// Returns all reducer ids and definitions in depth-first mount order.
///
/// IDs are assigned as follows: consumer's own reducers first (0..N), then each
/// mounted submodule's reducers in the order they appear in `mounts`, recursively.
pub fn reducer_ids_and_defs(&self) -> Vec<(ReducerId, &ReducerDef)> {
let mut out = Vec::with_capacity(self.reducer_count());
self.collect_reducers(0, &mut out);
out
}

/// Total reducer count including all mounted submodules (depth-first sum).
pub fn reducer_count(&self) -> usize {
self.reducers.len() + self.mounts.values().map(|m| m.reducer_count()).sum::<usize>()
}

fn collect_reducers<'a>(&'a self, offset: usize, out: &mut Vec<(ReducerId, &'a ReducerDef)>) {
for (i, def) in self.reducers.values().enumerate() {
out.push(((offset + i).into(), def));
}
let mut child_offset = offset + self.reducers.len();
for mount in self.mounts.values() {
mount.collect_reducers(child_offset, out);
child_offset += mount.reducer_count();
}
}

/// The procedures of the module definition.
Expand Down Expand Up @@ -304,14 +325,49 @@ impl ModuleDef {
self.reducers.get_full(name).map(|(idx, _, def)| (idx.into(), def))
}

/// Look up a reducer by its id.
/// Look up a reducer by its wire name, resolving qualified names like `"myauth/verify_token"`.
///
/// A plain name searches the consumer's own reducers. A slash-qualified name routes to
/// the matching mount and recurses. Nesting is supported: `"auth/baz/cleanup"`.
/// Returns the depth-first `ReducerId` and the `ReducerDef`.
pub fn reducer_by_name(&self, name: &str) -> Option<(ReducerId, &ReducerDef)> {
match name.split_once('/') {
None => self.reducer_full(name),
Some((namespace, rest)) => {
let mut offset = self.reducers.len();
for (ns, mount) in &self.mounts {
if ns == namespace {
let (inner_id, def) = mount.reducer_by_name(rest)?;
return Some(((offset + inner_id.idx()).into(), def));
}
offset += mount.reducer_count();
}
None
}
}
}

/// Look up a reducer by its depth-first id.
pub fn reducer_by_id(&self, id: ReducerId) -> &ReducerDef {
&self.reducers[id.idx()]
self.get_reducer_by_id(id)
.unwrap_or_else(|| panic!("reducer id {id:?} out of range"))
}

/// Look up a reducer by its id.
/// Look up a reducer by its depth-first id, returning `None` if it doesn't exist.
pub fn get_reducer_by_id(&self, id: ReducerId) -> Option<&ReducerDef> {
self.reducers.get_index(id.idx()).map(|(_, def)| def)
let idx = id.idx();
if idx < self.reducers.len() {
return self.reducers.get_index(idx).map(|(_, def)| def);
}
let mut offset = self.reducers.len();
for mount in self.mounts.values() {
let count = mount.reducer_count();
if idx < offset + count {
return mount.get_reducer_by_id(ReducerId::from(idx - offset));
}
offset += count;
}
None
}

/// Look up a view by its id, and whether it is anonymous.
Expand Down Expand Up @@ -2053,4 +2109,78 @@ mod tests {
.count()
== 2))
}

#[test]
fn mounted_reducer_ids_are_depth_first() {
use spacetimedb_lib::db::raw_def::v10::{
RawModuleDefV10Builder, RawModuleDefV10Section, RawModuleMountV10,
};

// baz library: 1 reducer
let mut baz_builder = RawModuleDefV10Builder::new();
baz_builder.add_reducer("baz_reduce", ProductType::unit());

// auth library: 1 own reducer, mounts baz
let mut auth_builder = RawModuleDefV10Builder::new();
auth_builder.add_reducer("auth_verify", ProductType::unit());
let mut auth_raw = auth_builder.finish();
auth_raw.sections.push(RawModuleDefV10Section::Mounts(vec![RawModuleMountV10 {
namespace: "baz".to_string(),
module: baz_builder.finish(),
}]));

// consumer: 2 own reducers, mounts auth
let mut consumer_builder = RawModuleDefV10Builder::new();
consumer_builder.add_reducer("consumer_a", ProductType::unit());
consumer_builder.add_reducer("consumer_b", ProductType::unit());
let mut consumer_raw = consumer_builder.finish();
consumer_raw.sections.push(RawModuleDefV10Section::Mounts(vec![RawModuleMountV10 {
namespace: "auth".to_string(),
module: auth_raw,
}]));

let def: ModuleDef = consumer_raw.try_into().expect("valid module");

// Total count: 2 consumer + 1 auth + 1 baz
assert_eq!(def.reducer_count(), 4);

// Depth-first order: consumer_a=0, consumer_b=1, auth_verify=2, baz_reduce=3
let ids_and_defs = def.reducer_ids_and_defs();
assert_eq!(ids_and_defs.len(), 4);
assert_eq!(ids_and_defs[0].0, ReducerId(0));
assert_eq!(&*ids_and_defs[0].1.name, "consumer_a");
assert_eq!(ids_and_defs[1].0, ReducerId(1));
assert_eq!(&*ids_and_defs[1].1.name, "consumer_b");
assert_eq!(ids_and_defs[2].0, ReducerId(2));
assert_eq!(&*ids_and_defs[2].1.name, "auth_verify");
assert_eq!(ids_and_defs[3].0, ReducerId(3));
assert_eq!(&*ids_and_defs[3].1.name, "baz_reduce");

// get_reducer_by_id resolves mounted reducer IDs correctly
assert_eq!(&*def.reducer_by_id(ReducerId(2)).name, "auth_verify");
assert_eq!(&*def.reducer_by_id(ReducerId(3)).name, "baz_reduce");
assert!(def.get_reducer_by_id(ReducerId(4)).is_none());

// reducer_by_name routes plain names to own reducers
let (id, rdef) = def.reducer_by_name("consumer_a").expect("plain name resolves");
assert_eq!(id, ReducerId(0));
assert_eq!(&*rdef.name, "consumer_a");

// reducer_by_name routes qualified names to mounted reducers
let (id, rdef) = def.reducer_by_name("auth/auth_verify").expect("qualified name resolves");
assert_eq!(id, ReducerId(2));
assert_eq!(&*rdef.name, "auth_verify");

// reducer_by_name routes deeply nested qualified names
let (id, rdef) = def
.reducer_by_name("auth/baz/baz_reduce")
.expect("nested qualified name resolves");
assert_eq!(id, ReducerId(3));
assert_eq!(&*rdef.name, "baz_reduce");

// Non-existent names return None
assert!(def.reducer_by_name("auth/nonexistent").is_none());
assert!(def.reducer_by_name("nonexistent").is_none());
assert!(def.reducer_by_name("nonamespace/auth_verify").is_none());
}
}
Loading