From 6d04e9da5c72a4eb286443867bc406fc66c71930 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Tue, 13 Jan 2026 17:02:24 -0800 Subject: [PATCH 01/29] Add famst crate: Fast Approximate Minimum Spanning Tree Implementation of the FAMST algorithm from Almansoori & Telek (2025). Features: - Generic over data type T and distance function - Uses NN-Descent for ANN graph construction - Three-phase approach: ANN graph, component connection, edge refinement - O(dn log n) time complexity, O(dn + kn) space complexity Paper: https://arxiv.org/abs/2507.14261 --- crates/famst/Cargo.toml | 12 + crates/famst/src/lib.rs | 1026 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 1038 insertions(+) create mode 100644 crates/famst/Cargo.toml create mode 100644 crates/famst/src/lib.rs diff --git a/crates/famst/Cargo.toml b/crates/famst/Cargo.toml new file mode 100644 index 0000000..da26573 --- /dev/null +++ b/crates/famst/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "famst" +version = "0.1.0" +edition = "2024" +description = "Fast Approximate Minimum Spanning Tree (FAMST) algorithm" +license = "MIT" + +[dependencies] +rand = "0.8" + +[dev-dependencies] +rand = "0.8" diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs new file mode 100644 index 0000000..6c97bf6 --- /dev/null +++ b/crates/famst/src/lib.rs @@ -0,0 +1,1026 @@ +//! FAMST: Fast Approximate Minimum Spanning Tree +//! +//! Implementation of the FAMST algorithm from: +//! "FAMST: Fast Approximate Minimum Spanning Tree Construction for Large-Scale +//! and High-Dimensional Data" (Almansoori & Telek, 2025) +//! +//! The algorithm uses three phases: +//! 1. ANN graph construction using NN-Descent +//! 2. Component analysis and connection with random edges +//! 3. Iterative edge refinement +//! +//! Generic over data type `T` and distance function. 
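+//!
+//! A minimal usage sketch (illustrative only; it mirrors the unit tests below and
+//! assumes nothing beyond the public API defined in this file):
+//!
+//! ```
+//! use famst::{famst, FamstConfig};
+//!
+//! // Each point is a Vec<f64>; the distance function is any Fn(&T, &T) -> f64.
+//! let points: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
+//! let euclidean = |a: &Vec<f64>, b: &Vec<f64>| {
+//!     a.iter()
+//!         .zip(b.iter())
+//!         .map(|(x, y)| (x - y).powi(2))
+//!         .sum::<f64>()
+//!         .sqrt()
+//! };
+//! let result = famst(&points, euclidean, &FamstConfig::default());
+//! assert_eq!(result.edges.len(), points.len() - 1); // an MST has n - 1 edges
+//! ```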
+ +use rand::seq::SliceRandom; +use rand::Rng; +use std::collections::{BinaryHeap, HashMap, HashSet}; + +/// An edge in the MST, represented as (node_a, node_b, distance) +#[derive(Debug, Clone)] +pub struct Edge { + pub u: usize, + pub v: usize, + pub distance: f64, +} + +impl Edge { + pub fn new(u: usize, v: usize, distance: f64) -> Self { + Edge { u, v, distance } + } +} + +/// Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm +pub struct UnionFind { + parent: Vec, + rank: Vec, +} + +impl UnionFind { + pub fn new(n: usize) -> Self { + UnionFind { + parent: (0..n).collect(), + rank: vec![0; n], + } + } + + pub fn find(&mut self, x: usize) -> usize { + if self.parent[x] != x { + self.parent[x] = self.find(self.parent[x]); // Path compression + } + self.parent[x] + } + + pub fn union(&mut self, x: usize, y: usize) -> bool { + let px = self.find(x); + let py = self.find(y); + if px == py { + return false; + } + // Union by rank + match self.rank[px].cmp(&self.rank[py]) { + std::cmp::Ordering::Less => self.parent[px] = py, + std::cmp::Ordering::Greater => self.parent[py] = px, + std::cmp::Ordering::Equal => { + self.parent[py] = px; + self.rank[px] += 1; + } + } + true + } +} + +/// Approximate Nearest Neighbors graph representation +/// Contains neighbor indices and distances for each point +pub struct AnnGraph { + /// neighbors[i] contains the indices of k nearest neighbors of point i + pub neighbors: Vec>, + /// distances[i] contains the distances to k nearest neighbors of point i + pub distances: Vec>, +} + +impl AnnGraph { + pub fn new(neighbors: Vec>, distances: Vec>) -> Self { + assert_eq!(neighbors.len(), distances.len()); + AnnGraph { + neighbors, + distances, + } + } + + pub fn n(&self) -> usize { + self.neighbors.len() + } +} + +/// FAMST algorithm configuration +pub struct FamstConfig { + /// Number of nearest neighbors (k in k-NN graph) + pub k: usize, + /// Number of random edges per component pair (λ in the paper) + pub lambda: usize, + /// Maximum refinement iterations (0 for unlimited until convergence) + pub max_iterations: usize, + /// Maximum NN-Descent iterations + pub nn_descent_iterations: usize, + /// Sample rate for NN-Descent (fraction of neighbors to sample) + pub nn_descent_sample_rate: f64, +} + +impl Default for FamstConfig { + fn default() -> Self { + FamstConfig { + k: 20, + lambda: 5, + max_iterations: 100, + nn_descent_iterations: 10, + nn_descent_sample_rate: 0.5, + } + } +} + +/// Result of FAMST algorithm +pub struct FamstResult { + /// MST edges + pub edges: Vec, + /// Total weight of the MST + pub total_weight: f64, +} + +/// Main FAMST algorithm implementation +/// +/// Generic over: +/// - `T`: The data type stored at each point +/// - `D`: Distance function `Fn(&T, &T) -> f64` +/// +/// # Arguments +/// * `data` - Slice of data points +/// * `distance_fn` - Function to compute distance between two points +/// * `config` - Algorithm configuration +/// +/// # Returns +/// The approximate MST as a list of edges +pub fn famst(data: &[T], distance_fn: D, config: &FamstConfig) -> FamstResult +where + D: Fn(&T, &T) -> f64, +{ + famst_with_rng(data, distance_fn, config, &mut rand::thread_rng()) +} + +/// FAMST with custom RNG for reproducibility +pub fn famst_with_rng( + data: &[T], + distance_fn: D, + config: &FamstConfig, + rng: &mut R, +) -> FamstResult +where + D: Fn(&T, &T) -> f64, + R: Rng, +{ + let n = data.len(); + if n == 0 { + return FamstResult { + edges: vec![], + total_weight: 0.0, + }; + } + if n == 1 { + return FamstResult 
{ + edges: vec![], + total_weight: 0.0, + }; + } + + // Phase 1: Build ANN graph using NN-Descent + let ann_graph = nn_descent(data, &distance_fn, config, rng); + + // Phase 2: Build undirected graph and find connected components + let (undirected_graph, components) = find_components(&ann_graph); + + // If only one component, skip inter-component edge logic + println!("components {}", components.len()); + if components.len() <= 1 { + let edges = extract_mst_from_ann(&ann_graph, n); + let total_weight = edges.iter().map(|e| e.distance).sum(); + return FamstResult { + edges, + total_weight, + }; + } + + // Phase 2 continued: Add random edges between components + let (mut inter_edges, edge_components) = + add_random_edges(data, &components, config.lambda, &distance_fn, rng); + + // Phase 3: Iterative edge refinement + let mut iterations = 0; + loop { + let (refined_edges, changes) = refine_edges( + data, + &undirected_graph, + &components, + &inter_edges, + &edge_components, + &distance_fn, + ); + inter_edges = refined_edges; + + if changes == 0 { + break; + } + + iterations += 1; + if config.max_iterations > 0 && iterations >= config.max_iterations { + break; + } + } + + // Phase 4: Extract MST using Kruskal's algorithm + let edges = extract_mst(&ann_graph, &inter_edges, n); + let total_weight = edges.iter().map(|e| e.distance).sum(); + + FamstResult { + edges, + total_weight, + } +} + +/// A neighbor entry in the k-NN heap (max-heap by distance for easy replacement of farthest) +#[derive(Clone, Copy)] +struct NeighborEntry { + index: usize, + distance: f64, +} + +impl PartialEq for NeighborEntry { + fn eq(&self, other: &Self) -> bool { + self.distance == other.distance + } +} + +impl Eq for NeighborEntry {} + +impl PartialOrd for NeighborEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for NeighborEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Max-heap: larger distances have higher priority + self.distance + .partial_cmp(&other.distance) + .unwrap_or(std::cmp::Ordering::Equal) + } +} + +/// NN-Descent algorithm for approximate k-NN graph construction +/// +/// Based on: "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures" +/// by Wei Dong, Charikar Moses, and Kai Li (2011) +fn nn_descent(data: &[T], distance_fn: &D, config: &FamstConfig, rng: &mut R) -> AnnGraph +where + D: Fn(&T, &T) -> f64, + R: Rng, +{ + let n = data.len(); + let k = config.k.min(n - 1); + + if k == 0 || n <= 1 { + return AnnGraph::new(vec![vec![]; n], vec![vec![]; n]); + } + + // Initialize with random neighbors using max-heap for each point + let mut heaps: Vec> = Vec::with_capacity(n); + let mut neighbor_sets: Vec> = vec![HashSet::with_capacity(k); n]; + + for i in 0..n { + let mut heap = BinaryHeap::with_capacity(k); + let mut indices: Vec = (0..n).filter(|&j| j != i).collect(); + indices.shuffle(rng); + + for &j in indices.iter().take(k) { + let d = distance_fn(&data[i], &data[j]); + heap.push(NeighborEntry { + index: j, + distance: d, + }); + neighbor_sets[i].insert(j); + } + heaps.push(heap); + } + + // Build reverse neighbor lists (who has me as a neighbor) + let build_reverse = |neighbor_sets: &[HashSet]| -> Vec> { + let mut reverse: Vec> = vec![HashSet::new(); n]; + for (i, neighbors) in neighbor_sets.iter().enumerate() { + for &j in neighbors { + reverse[j].insert(i); + } + } + reverse + }; + + // NN-Descent iterations + for _ in 0..config.nn_descent_iterations { + let mut updates = 0; + let reverse_neighbors = 
build_reverse(&neighbor_sets); + + // For each point, explore neighbors of neighbors + for i in 0..n { + // Collect candidates: neighbors and reverse neighbors + let mut candidates: Vec = Vec::new(); + + // Sample from forward neighbors + let forward: Vec = neighbor_sets[i].iter().copied().collect(); + let sample_size = + ((forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); + let mut sampled_forward = forward.clone(); + sampled_forward.shuffle(rng); + sampled_forward.truncate(sample_size); + + // Sample from reverse neighbors + let reverse: Vec = reverse_neighbors[i].iter().copied().collect(); + let sample_size = + ((reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); + let mut sampled_reverse = reverse.clone(); + sampled_reverse.shuffle(rng); + sampled_reverse.truncate(sample_size); + + // Neighbors of neighbors + for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) { + for &nn in &neighbor_sets[neighbor] { + if nn != i && !neighbor_sets[i].contains(&nn) { + candidates.push(nn); + } + } + // Also check reverse neighbors of neighbors + for &rn in &reverse_neighbors[neighbor] { + if rn != i && !neighbor_sets[i].contains(&rn) { + candidates.push(rn); + } + } + } + + // Deduplicate candidates + candidates.sort_unstable(); + candidates.dedup(); + + // Try to improve neighbors + for c in candidates { + let d = distance_fn(&data[i], &data[c]); + + // Check if this is better than the worst current neighbor + if let Some(worst) = heaps[i].peek() { + if d < worst.distance { + // Remove worst and add new neighbor + let removed = heaps[i].pop().unwrap(); + neighbor_sets[i].remove(&removed.index); + + heaps[i].push(NeighborEntry { + index: c, + distance: d, + }); + neighbor_sets[i].insert(c); + updates += 1; + } + } + } + } + + // Early termination if no updates + if updates == 0 { + break; + } + } + + // Convert heaps to sorted neighbor lists + let mut neighbors = vec![Vec::with_capacity(k); n]; + let mut distances = vec![Vec::with_capacity(k); n]; + + for (i, heap) in heaps.into_iter().enumerate() { + let mut entries: Vec = heap.into_vec(); + entries.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + + for entry in entries { + neighbors[i].push(entry.index); + distances[i].push(entry.distance); + } + } + + AnnGraph::new(neighbors, distances) +} + +/// Find connected components in the ANN graph using DFS +/// Returns the undirected graph adjacency list and component assignments +fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { + let n = ann_graph.n(); + + // Build undirected graph from directed ANN graph + let mut graph: Vec> = vec![HashSet::new(); n]; + for (i, neighbors) in ann_graph.neighbors.iter().enumerate() { + for &j in neighbors { + graph[i].insert(j); + graph[j].insert(i); + } + } + + // DFS to find components + let mut visited = vec![false; n]; + let mut components: Vec> = Vec::new(); + + for start in 0..n { + if visited[start] { + continue; + } + + let mut component = Vec::new(); + let mut stack = vec![start]; + + while let Some(u) = stack.pop() { + if visited[u] { + continue; + } + visited[u] = true; + component.push(u); + + for &v in &graph[u] { + if !visited[v] { + stack.push(v); + } + } + } + + components.push(component); + } + + (graph, components) +} + +/// Add random edges between components (Algorithm 3 in the paper) +fn add_random_edges( + data: &[T], + components: &[Vec], + lambda: usize, + distance_fn: &D, + rng: &mut R, +) -> (Vec, Vec<(usize, usize)>) +where + D: Fn(&T, &T) -> f64, 
+ R: Rng, +{ + let t = components.len(); + let mut edges = Vec::new(); + let mut edge_components = Vec::new(); + + let lambda_sq = lambda * lambda; + + for i in 0..t { + for j in (i + 1)..t { + let mut candidates: Vec = Vec::with_capacity(lambda_sq); + + // Generate λ² candidate edges + for _ in 0..lambda_sq { + let u = *components[i].choose(rng).unwrap(); + let v = *components[j].choose(rng).unwrap(); + let d = distance_fn(&data[u], &data[v]); + candidates.push(Edge::new(u, v, d)); + } + + // Sort by distance and take top λ + candidates.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + + for edge in candidates.into_iter().take(lambda) { + edges.push(edge); + edge_components.push((i, j)); + } + } + } + + (edges, edge_components) +} + +/// Refine inter-component edges (Algorithm 4 in the paper) +fn refine_edges( + data: &[T], + undirected_graph: &[HashSet], + components: &[Vec], + edges: &[Edge], + edge_components: &[(usize, usize)], + distance_fn: &D, +) -> (Vec, usize) +where + D: Fn(&T, &T) -> f64, +{ + // Build component membership lookup + let mut node_to_component: HashMap = HashMap::new(); + for (comp_idx, component) in components.iter().enumerate() { + for &node in component { + node_to_component.insert(node, comp_idx); + } + } + + // Build component node sets for quick lookup + let component_sets: Vec> = components + .iter() + .map(|c| c.iter().copied().collect()) + .collect(); + + let mut refined_edges = Vec::with_capacity(edges.len()); + let mut changes = 0; + + for (edge, &(ci, cj)) in edges.iter().zip(edge_components.iter()) { + let mut best_u = edge.u; + let mut best_v = edge.v; + let mut best_d = edge.distance; + + // Get neighbors of u that are in component ci + let neighbors_u: Vec = undirected_graph[edge.u] + .iter() + .filter(|&&n| component_sets[ci].contains(&n)) + .copied() + .collect(); + + // Try to find better u from neighbors + for u_prime in neighbors_u { + if u_prime == edge.v { + continue; + } + let d_prime = distance_fn(&data[u_prime], &data[best_v]); + if d_prime < best_d { + best_u = u_prime; + best_d = d_prime; + } + } + + // Get neighbors of v that are in component cj + let neighbors_v: Vec = undirected_graph[edge.v] + .iter() + .filter(|&&n| component_sets[cj].contains(&n)) + .copied() + .collect(); + + // Try to find better v from neighbors (using updated best_u) + for v_prime in neighbors_v { + if v_prime == edge.u { + continue; + } + let d_prime = distance_fn(&data[best_u], &data[v_prime]); + if d_prime < best_d { + best_v = v_prime; + best_d = d_prime; + } + } + + if best_u != edge.u || best_v != edge.v { + changes += 1; + } + + refined_edges.push(Edge::new(best_u, best_v, best_d)); + } + + (refined_edges, changes) +} + +/// Extract MST using Kruskal's algorithm on the connected ANN graph +fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec { + // Collect all edges from ANN graph + let mut all_edges: Vec = Vec::new(); + + for (i, (neighbors, distances)) in ann_graph + .neighbors + .iter() + .zip(ann_graph.distances.iter()) + .enumerate() + { + for (&j, &d) in neighbors.iter().zip(distances.iter()) { + // Only add edge once (i < j) + if i < j { + all_edges.push(Edge::new(i, j, d)); + } + } + } + + // Add edges where j < i (reverse direction in ANN graph) + for (i, (neighbors, distances)) in ann_graph + .neighbors + .iter() + .zip(ann_graph.distances.iter()) + .enumerate() + { + for (&j, &d) in neighbors.iter().zip(distances.iter()) { + if j < i { + // Check if this edge isn't already added + 
all_edges.push(Edge::new(j, i, d)); + } + } + } + + // Add inter-component edges + for edge in inter_edges { + all_edges.push(edge.clone()); + } + + // Deduplicate edges + let mut edge_set: HashMap<(usize, usize), f64> = HashMap::new(); + for edge in all_edges { + let key = if edge.u < edge.v { + (edge.u, edge.v) + } else { + (edge.v, edge.u) + }; + edge_set + .entry(key) + .and_modify(|d| { + if edge.distance < *d { + *d = edge.distance + } + }) + .or_insert(edge.distance); + } + + let mut edges: Vec = edge_set + .into_iter() + .map(|((u, v), d)| Edge::new(u, v, d)) + .collect(); + + // Sort edges by weight + edges.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + + // Kruskal's algorithm + let mut uf = UnionFind::new(n); + let mut mst_edges = Vec::with_capacity(n - 1); + + for edge in edges { + if uf.union(edge.u, edge.v) { + mst_edges.push(edge); + if mst_edges.len() == n - 1 { + break; + } + } + } + + mst_edges +} + +/// Extract MST when graph is already connected (single component) +fn extract_mst_from_ann(ann_graph: &AnnGraph, n: usize) -> Vec { + extract_mst(ann_graph, &[], n) +} + +/// Euclidean distance for slices of f64 +pub fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { + a.iter() + .zip(b.iter()) + .map(|(x, y)| (x - y).powi(2)) + .sum::() + .sqrt() +} + +/// Manhattan distance for slices of f64 +pub fn manhattan_distance(a: &[f64], b: &[f64]) -> f64 { + a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum() +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::StdRng; + use rand::SeedableRng; + + #[test] + fn test_union_find() { + let mut uf = UnionFind::new(5); + assert!(uf.union(0, 1)); + assert!(uf.union(2, 3)); + assert!(!uf.union(0, 1)); // Already same set + assert!(uf.union(1, 2)); + assert_eq!(uf.find(0), uf.find(3)); + } + + #[test] + fn test_simple_mst() { + // Simple 2D points forming a triangle + let points: Vec> = vec![ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![0.5, 0.866], // Equilateral triangle + ]; + + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig { + k: 2, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 2); // MST has n-1 edges + } + + #[test] + fn test_line_points() { + // Points on a line + let points: Vec> = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]]; + + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig { + k: 2, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 4); + // Total weight should be 4.0 (1+1+1+1) + assert!((result.total_weight - 4.0).abs() < 1e-10); + } + + #[test] + fn test_disconnected_components() { + // Two clusters far apart + let points: Vec> = vec![ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![0.5, 0.5], + vec![100.0, 100.0], + vec![101.0, 100.0], + vec![100.5, 100.5], + ]; + + // k=1 will likely create disconnected components + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig { + k: 1, + lambda: 3, + max_iterations: 10, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 5); // MST has n-1 edges + } + + #[test] + fn test_custom_distance() { + // Test with Manhattan distance + let points: Vec> = 
vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]]; + + let distance = |a: &Vec, b: &Vec| manhattan_distance(a, b); + let config = FamstConfig { + k: 2, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 2); + // Manhattan distance from (0,0) to (1,1) is 2, and (1,1) to (2,2) is 2 + assert!((result.total_weight - 4.0).abs() < 1e-10); + } + + #[test] + fn test_generic_data_type() { + // Test with a custom struct + #[derive(Clone)] + struct Point3D { + x: f64, + y: f64, + z: f64, + } + + fn point_distance(a: &Point3D, b: &Point3D) -> f64 { + ((a.x - b.x).powi(2) + (a.y - b.y).powi(2) + (a.z - b.z).powi(2)).sqrt() + } + + let points = vec![ + Point3D { + x: 0.0, + y: 0.0, + z: 0.0, + }, + Point3D { + x: 1.0, + y: 0.0, + z: 0.0, + }, + Point3D { + x: 0.0, + y: 1.0, + z: 0.0, + }, + Point3D { + x: 0.0, + y: 0.0, + z: 1.0, + }, + ]; + + let config = FamstConfig { + k: 3, + ..Default::default() + }; + + let mut rng = StdRng::seed_from_u64(42); + let result = famst_with_rng(&points, point_distance, &config, &mut rng); + + assert_eq!(result.edges.len(), 3); + } + + #[test] + fn test_multiple_clusters() { + // Create 5 well-separated clusters to force multiple components with small k + // Each cluster is a tight group of points, clusters are far apart + use rand::distributions::{Distribution, Uniform}; + + let mut rng = StdRng::seed_from_u64(77777); + let noise = Uniform::new(-0.5, 0.5); + + let cluster_centers = vec![ + vec![0.0, 0.0], + vec![100.0, 0.0], + vec![0.0, 100.0], + vec![100.0, 100.0], + vec![50.0, 50.0], + ]; + + let points_per_cluster = 20; + let mut points: Vec> = Vec::new(); + + for center in &cluster_centers { + for _ in 0..points_per_cluster { + let point = vec![ + center[0] + noise.sample(&mut rng), + center[1] + noise.sample(&mut rng), + ]; + points.push(point); + } + } + + let n = points.len(); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + + // Use small k to create disconnected components + // With k=3 and 20 points per cluster spread over 5 clusters, + // each point's 3 nearest neighbors will be in its own cluster + let config = FamstConfig { + k: 3, + lambda: 5, + max_iterations: 50, + nn_descent_iterations: 20, + nn_descent_sample_rate: 1.0, // Full sampling for small dataset + }; + + let mut famst_rng = StdRng::seed_from_u64(88888); + let result = famst_with_rng(&points, distance, &config, &mut famst_rng); + + // Should produce a valid MST with n-1 edges + assert_eq!(result.edges.len(), n - 1, "MST should have n-1 edges"); + + // Verify connectivity: all nodes should be reachable + let mut uf = UnionFind::new(n); + for edge in &result.edges { + uf.union(edge.u, edge.v); + } + // Check all nodes are in the same component + let root = uf.find(0); + for i in 1..n { + assert_eq!(uf.find(i), root, "All nodes should be connected in the MST"); + } + + // Compare with exact MST + let exact_weight = exact_mst_weight(&points, distance); + let error_ratio = (result.total_weight - exact_weight) / exact_weight; + + println!( + "Multi-cluster test: Exact MST weight: {:.4}, FAMST weight: {:.4}, error: {:.2}%", + exact_weight, + result.total_weight, + error_ratio * 100.0 + ); + + // Should be reasonably close (within 15% given the challenging setup) + assert!( + error_ratio < 0.15, + "FAMST error should be < 15%, got {:.2}%", + error_ratio * 100.0 + ); + } + + /// Compute exact MST using Kruskal's algorithm on complete graph + fn 
exact_mst_weight(data: &[T], distance_fn: D) -> f64 + where + D: Fn(&T, &T) -> f64, + { + let n = data.len(); + if n <= 1 { + return 0.0; + } + + // Build all edges + let mut edges: Vec<(usize, usize, f64)> = Vec::with_capacity(n * (n - 1) / 2); + for i in 0..n { + for j in (i + 1)..n { + let d = distance_fn(&data[i], &data[j]); + edges.push((i, j, d)); + } + } + + // Sort by weight + edges.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap()); + + // Kruskal's algorithm + let mut uf = UnionFind::new(n); + let mut total_weight = 0.0; + let mut edge_count = 0; + + for (u, v, w) in edges { + if uf.union(u, v) { + total_weight += w; + edge_count += 1; + if edge_count == n - 1 { + break; + } + } + } + + total_weight + } + + #[test] + #[ignore] // Run with: cargo test large_scale -- --ignored --nocapture + fn test_large_scale_vs_exact() { + use rand::distributions::{Distribution, Uniform}; + + const N: usize = 1_000_000; + const DIM: usize = 10; + + println!("Generating {} random {}-dimensional points...", N, DIM); + let mut rng = StdRng::seed_from_u64(12345); + let dist = Uniform::new(0.0, 1000.0); + + let points: Vec> = (0..N) + .map(|_| (0..DIM).map(|_| dist.sample(&mut rng)).collect()) + .collect(); + + println!("Running FAMST with NN-Descent..."); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig { + k: 20, + lambda: 5, + max_iterations: 100, + nn_descent_iterations: 10, + nn_descent_sample_rate: 0.5, + }; + let mut famst_rng = StdRng::seed_from_u64(54321); + let start = std::time::Instant::now(); + let result = famst_with_rng(&points, distance, &config, &mut famst_rng); + let famst_time = start.elapsed(); + + println!("FAMST completed in {:?}", famst_time); + println!("FAMST MST weight: {:.4}", result.total_weight); + println!("FAMST MST edges: {}", result.edges.len()); + + assert_eq!(result.edges.len(), N - 1, "MST should have n-1 edges"); + } + + #[test] + fn test_medium_scale_vs_exact() { + use rand::distributions::{Distribution, Uniform}; + + const N: usize = 5000; + const DIM: usize = 5; + + let mut rng = StdRng::seed_from_u64(99999); + let dist = Uniform::new(0.0, 100.0); + + let points: Vec> = (0..N) + .map(|_| (0..DIM).map(|_| dist.sample(&mut rng)).collect()) + .collect(); + + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + + // Compute exact MST + let exact_weight = exact_mst_weight(&points, distance); + + // Compute approximate MST with FAMST + let config = FamstConfig { + k: 15, + lambda: 5, + max_iterations: 100, + nn_descent_iterations: 15, + nn_descent_sample_rate: 0.5, + }; + let mut famst_rng = StdRng::seed_from_u64(11111); + let result = famst_with_rng(&points, distance, &config, &mut famst_rng); + + assert_eq!(result.edges.len(), N - 1); + + // FAMST should produce a weight >= exact (it's an approximation) + // and should be reasonably close (within a few percent for good k) + let error_ratio = (result.total_weight - exact_weight) / exact_weight; + println!( + "Exact MST weight: {:.4}, FAMST weight: {:.4}, error: {:.2}%", + exact_weight, + result.total_weight, + error_ratio * 100.0 + ); + + // The approximation should be within 10% for this setup with NN-Descent + assert!( + error_ratio >= 0.0, + "FAMST weight should be >= exact MST weight" + ); + assert!( + error_ratio < 0.10, + "FAMST error should be < 10%, got {:.2}%", + error_ratio * 100.0 + ); + } +} From c09952fffb25c7063a0774ba62117e2f73bdcf5e Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 14:52:08 -0800 Subject: [PATCH 02/29] 
=?UTF-8?q?Fix=20O(n=C2=B2)=20initialization=20bug?= =?UTF-8?q?=20in=20NN-Descent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use Floyd's algorithm to sample k random neighbors in O(k) time instead of allocating and shuffling all n indices per point. --- crates/famst/src/lib.rs | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 6c97bf6..296e8dd 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -277,16 +277,33 @@ where for i in 0..n { let mut heap = BinaryHeap::with_capacity(k); - let mut indices: Vec = (0..n).filter(|&j| j != i).collect(); - indices.shuffle(rng); - - for &j in indices.iter().take(k) { - let d = distance_fn(&data[i], &data[j]); - heap.push(NeighborEntry { - index: j, - distance: d, - }); - neighbor_sets[i].insert(j); + + // Sample k random neighbors using Floyd's algorithm - guaranteed O(k) + // https://fermatslibrary.com/s/a-sample-of-brilliance + // This selects k distinct elements from 0..n, excluding i + let effective_n = n - 1; // exclude self + let range_start = effective_n.saturating_sub(k); + for t in range_start..effective_n { + let j = rng.gen_range(0..=t); + // Map j to actual index, skipping i + let actual_j = if j >= i { j + 1 } else { j }; + + if !neighbor_sets[i].insert(actual_j) { + // j was already selected, so add t instead + let actual_t = if t >= i { t + 1 } else { t }; + neighbor_sets[i].insert(actual_t); + let d = distance_fn(&data[i], &data[actual_t]); + heap.push(NeighborEntry { + index: actual_t, + distance: d, + }); + } else { + let d = distance_fn(&data[i], &data[actual_j]); + heap.push(NeighborEntry { + index: actual_j, + distance: d, + }); + } } heaps.push(heap); } From 4ad3eb2033c1e71476960097498cd1a5ad2ae521 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 15:34:00 -0800 Subject: [PATCH 03/29] Unify duplicate loops in extract_mst --- crates/famst/src/lib.rs | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 296e8dd..71fb291 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -592,25 +592,8 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec Date: Wed, 14 Jan 2026 15:36:30 -0800 Subject: [PATCH 04/29] Remove unnecessary edge deduplication in extract_mst Kruskal's algorithm naturally skips duplicate edges via union-find. 
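For illustration (a sketch using this crate's own UnionFind, whose union()
returns false when both endpoints already share a root): the first copy of an
edge merges the two sets, and any later copy is rejected, so it never enters
the MST.

    let mut uf = UnionFind::new(2);
    assert!(uf.union(0, 1));  // first copy of the edge links the nodes
    assert!(!uf.union(0, 1)); // duplicate copy is a no-op and gets skipped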
--- crates/famst/src/lib.rs | 34 ++++------------------------------ 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 71fb291..f46349a 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -583,7 +583,7 @@ where /// Extract MST using Kruskal's algorithm on the connected ANN graph fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec { // Collect all edges from ANN graph - let mut all_edges: Vec = Vec::new(); + let mut edges: Vec = Vec::new(); for (i, (neighbors, distances)) in ann_graph .neighbors @@ -592,43 +592,17 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec = HashMap::new(); - for edge in all_edges { - let key = if edge.u < edge.v { - (edge.u, edge.v) - } else { - (edge.v, edge.u) - }; - edge_set - .entry(key) - .and_modify(|d| { - if edge.distance < *d { - *d = edge.distance - } - }) - .or_insert(edge.distance); - } - - let mut edges: Vec = edge_set - .into_iter() - .map(|((u, v), d)| Edge::new(u, v, d)) - .collect(); + edges.extend(inter_edges.iter().cloned()); // Sort edges by weight edges.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); - // Kruskal's algorithm + // Kruskal's algorithm (naturally handles duplicate edges) let mut uf = UnionFind::new(n); let mut mst_edges = Vec::with_capacity(n - 1); From c14f0cb9a8072506c3b7303bbe74680b3d10f374 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 15:38:46 -0800 Subject: [PATCH 05/29] Unify n==0 and n==1 cases, remove debug println --- crates/famst/src/lib.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index f46349a..8d06bbe 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -158,13 +158,7 @@ where R: Rng, { let n = data.len(); - if n == 0 { - return FamstResult { - edges: vec![], - total_weight: 0.0, - }; - } - if n == 1 { + if n <= 1 { return FamstResult { edges: vec![], total_weight: 0.0, @@ -178,7 +172,6 @@ where let (undirected_graph, components) = find_components(&ann_graph); // If only one component, skip inter-component edge logic - println!("components {}", components.len()); if components.len() <= 1 { let edges = extract_mst_from_ann(&ann_graph, n); let total_weight = edges.iter().map(|e| e.distance).sum(); From 4ccefd923a4cb45f69ddff9b1d68051b4f959e8c Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 15:39:52 -0800 Subject: [PATCH 06/29] Add tests for empty and single-point inputs --- crates/famst/src/lib.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 8d06bbe..b5a0209 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -636,6 +636,24 @@ mod tests { use rand::rngs::StdRng; use rand::SeedableRng; + #[test] + fn test_empty_input() { + let points: Vec> = vec![]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let result = famst(&points, distance, &FamstConfig::default()); + assert_eq!(result.edges.len(), 0); + assert_eq!(result.total_weight, 0.0); + } + + #[test] + fn test_single_point() { + let points: Vec> = vec![vec![1.0, 2.0, 3.0]]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let result = famst(&points, distance, &FamstConfig::default()); + assert_eq!(result.edges.len(), 0); + assert_eq!(result.total_weight, 0.0); + } + #[test] fn test_union_find() { let mut uf = UnionFind::new(5); From 
60d85f226ba97fbcff00eef02748b25d709a2c82 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 14 Jan 2026 15:41:35 -0800 Subject: [PATCH 07/29] Add test for k >= n case --- crates/famst/src/lib.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index b5a0209..932aa4d 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -654,6 +654,20 @@ mod tests { assert_eq!(result.total_weight, 0.0); } + #[test] + fn test_k_greater_than_n() { + // 3 points but k=20 (default), so k >= n + let points: Vec> = vec![ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![0.0, 1.0], + ]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let config = FamstConfig::default(); // k=20 > n=3 + let result = famst(&points, distance, &config); + assert_eq!(result.edges.len(), 2); // MST has n-1 edges + } + #[test] fn test_union_find() { let mut uf = UnionFind::new(5); From e37f359800641ecbc56552d8e0392d1caf6c41a9 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 11:06:47 -0600 Subject: [PATCH 08/29] Manual tweaks. --- crates/famst/src/lib.rs | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 932aa4d..b147a04 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -146,7 +146,7 @@ where famst_with_rng(data, distance_fn, config, &mut rand::thread_rng()) } -/// FAMST with custom RNG for reproducibility +/// FAMST with custom RNG. (We use a seeded RNG in tests for reproducibility.) pub fn famst_with_rng( data: &[T], distance_fn: D, @@ -616,26 +616,26 @@ fn extract_mst_from_ann(ann_graph: &AnnGraph, n: usize) -> Vec { extract_mst(ann_graph, &[], n) } -/// Euclidean distance for slices of f64 -pub fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { - a.iter() - .zip(b.iter()) - .map(|(x, y)| (x - y).powi(2)) - .sum::() - .sqrt() -} - -/// Manhattan distance for slices of f64 -pub fn manhattan_distance(a: &[f64], b: &[f64]) -> f64 { - a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum() -} - #[cfg(test)] mod tests { use super::*; use rand::rngs::StdRng; use rand::SeedableRng; + /// Manhattan distance for slices of f64 + pub fn manhattan_distance(a: &[f64], b: &[f64]) -> f64 { + a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum() + } + + /// Euclidean distance for slices of f64 + pub fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { + a.iter() + .zip(b.iter()) + .map(|(x, y)| (x - y).powi(2)) + .sum::() + .sqrt() + } + #[test] fn test_empty_input() { let points: Vec> = vec![]; @@ -657,11 +657,7 @@ mod tests { #[test] fn test_k_greater_than_n() { // 3 points but k=20 (default), so k >= n - let points: Vec> = vec![ - vec![0.0, 0.0], - vec![1.0, 0.0], - vec![0.0, 1.0], - ]; + let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]]; let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig::default(); // k=20 > n=3 let result = famst(&points, distance, &config); From b3a257b61d5adbe602af1bf40da64a922436ea3a Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 11:32:36 -0600 Subject: [PATCH 09/29] Replace HashSets with sorted Vecs for memory efficiency This reduces memory usage by ~5x for large graphs: - neighbor_lists in nn_descent: HashSet -> sorted Vec - reverse_neighbors in nn_descent: HashSet -> Vec (already sorted by construction) - graph in find_components: HashSet -> sorted Vec - node_to_component in refine_edges: 
HashMap -> Vec (O(1) indexed lookup) - Removed component_sets HashSets entirely For n=1 billion, k=20, this saves ~2.5 TB of memory. --- crates/famst/src/lib.rs | 142 +++++++++++++++++++++------------------- 1 file changed, 76 insertions(+), 66 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index b147a04..16f0f65 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -13,7 +13,7 @@ use rand::seq::SliceRandom; use rand::Rng; -use std::collections::{BinaryHeap, HashMap, HashSet}; +use std::collections::BinaryHeap; /// An edge in the MST, represented as (node_a, node_b, distance) #[derive(Debug, Clone)] @@ -264,9 +264,33 @@ where return AnnGraph::new(vec![vec![]; n], vec![vec![]; n]); } + // Helper: check if sorted vec contains value + fn sorted_contains(v: &[usize], x: usize) -> bool { + v.binary_search(&x).is_ok() + } + + // Helper: insert into sorted vec, returns true if inserted (was not present) + fn sorted_insert(v: &mut Vec, x: usize) -> bool { + match v.binary_search(&x) { + Ok(_) => false, + Err(pos) => { + v.insert(pos, x); + true + } + } + } + + // Helper: remove from sorted vec + fn sorted_remove(v: &mut Vec, x: usize) { + if let Ok(pos) = v.binary_search(&x) { + v.remove(pos); + } + } + // Initialize with random neighbors using max-heap for each point + // neighbor_lists[i] is kept sorted by index for O(log k) membership tests let mut heaps: Vec> = Vec::with_capacity(n); - let mut neighbor_sets: Vec> = vec![HashSet::with_capacity(k); n]; + let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; for i in 0..n { let mut heap = BinaryHeap::with_capacity(k); @@ -281,10 +305,10 @@ where // Map j to actual index, skipping i let actual_j = if j >= i { j + 1 } else { j }; - if !neighbor_sets[i].insert(actual_j) { + if !sorted_insert(&mut neighbor_lists[i], actual_j) { // j was already selected, so add t instead let actual_t = if t >= i { t + 1 } else { t }; - neighbor_sets[i].insert(actual_t); + sorted_insert(&mut neighbor_lists[i], actual_t); let d = distance_fn(&data[i], &data[actual_t]); heap.push(NeighborEntry { index: actual_t, @@ -302,20 +326,22 @@ where } // Build reverse neighbor lists (who has me as a neighbor) - let build_reverse = |neighbor_sets: &[HashSet]| -> Vec> { - let mut reverse: Vec> = vec![HashSet::new(); n]; - for (i, neighbors) in neighbor_sets.iter().enumerate() { + // Returns sorted vecs for each point + let build_reverse = |neighbor_lists: &[Vec]| -> Vec> { + let mut reverse: Vec> = vec![Vec::new(); n]; + for (i, neighbors) in neighbor_lists.iter().enumerate() { for &j in neighbors { - reverse[j].insert(i); + reverse[j].push(i); } } + // Sort each reverse list (they're built in order of i, so already sorted) reverse }; // NN-Descent iterations for _ in 0..config.nn_descent_iterations { let mut updates = 0; - let reverse_neighbors = build_reverse(&neighbor_sets); + let reverse_neighbors = build_reverse(&neighbor_lists); // For each point, explore neighbors of neighbors for i in 0..n { @@ -323,31 +349,29 @@ where let mut candidates: Vec = Vec::new(); // Sample from forward neighbors - let forward: Vec = neighbor_sets[i].iter().copied().collect(); + let mut sampled_forward = neighbor_lists[i].clone(); let sample_size = - ((forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); - let mut sampled_forward = forward.clone(); + ((sampled_forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); sampled_forward.shuffle(rng); sampled_forward.truncate(sample_size); // Sample from 
reverse neighbors - let reverse: Vec = reverse_neighbors[i].iter().copied().collect(); + let mut sampled_reverse = reverse_neighbors[i].clone(); let sample_size = - ((reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); - let mut sampled_reverse = reverse.clone(); + ((sampled_reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); sampled_reverse.shuffle(rng); sampled_reverse.truncate(sample_size); // Neighbors of neighbors for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) { - for &nn in &neighbor_sets[neighbor] { - if nn != i && !neighbor_sets[i].contains(&nn) { + for &nn in &neighbor_lists[neighbor] { + if nn != i && !sorted_contains(&neighbor_lists[i], nn) { candidates.push(nn); } } // Also check reverse neighbors of neighbors for &rn in &reverse_neighbors[neighbor] { - if rn != i && !neighbor_sets[i].contains(&rn) { + if rn != i && !sorted_contains(&neighbor_lists[i], rn) { candidates.push(rn); } } @@ -366,13 +390,13 @@ where if d < worst.distance { // Remove worst and add new neighbor let removed = heaps[i].pop().unwrap(); - neighbor_sets[i].remove(&removed.index); + sorted_remove(&mut neighbor_lists[i], removed.index); heaps[i].push(NeighborEntry { index: c, distance: d, }); - neighbor_sets[i].insert(c); + sorted_insert(&mut neighbor_lists[i], c); updates += 1; } } @@ -403,18 +427,23 @@ where } /// Find connected components in the ANN graph using DFS -/// Returns the undirected graph adjacency list and component assignments -fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { +/// Returns the undirected graph adjacency list (sorted vecs) and component assignments +fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { let n = ann_graph.n(); - // Build undirected graph from directed ANN graph - let mut graph: Vec> = vec![HashSet::new(); n]; + // Build undirected graph from directed ANN graph using sorted vecs + let mut graph: Vec> = vec![Vec::new(); n]; for (i, neighbors) in ann_graph.neighbors.iter().enumerate() { for &j in neighbors { - graph[i].insert(j); - graph[j].insert(i); + graph[i].push(j); + graph[j].push(i); } } + // Sort and deduplicate each adjacency list + for adj in &mut graph { + adj.sort_unstable(); + adj.dedup(); + } // DFS to find components let mut visited = vec![false; n]; @@ -494,7 +523,7 @@ where /// Refine inter-component edges (Algorithm 4 in the paper) fn refine_edges( data: &[T], - undirected_graph: &[HashSet], + undirected_graph: &[Vec], components: &[Vec], edges: &[Edge], edge_components: &[(usize, usize)], @@ -503,20 +532,16 @@ fn refine_edges( where D: Fn(&T, &T) -> f64, { - // Build component membership lookup - let mut node_to_component: HashMap = HashMap::new(); + let n = data.len(); + + // Build component membership lookup (simple vec, O(1) lookup) + let mut node_to_component: Vec = vec![0; n]; for (comp_idx, component) in components.iter().enumerate() { for &node in component { - node_to_component.insert(node, comp_idx); + node_to_component[node] = comp_idx; } } - // Build component node sets for quick lookup - let component_sets: Vec> = components - .iter() - .map(|c| c.iter().copied().collect()) - .collect(); - let mut refined_edges = Vec::with_capacity(edges.len()); let mut changes = 0; @@ -526,40 +551,25 @@ where let mut best_d = edge.distance; // Get neighbors of u that are in component ci - let neighbors_u: Vec = undirected_graph[edge.u] - .iter() - .filter(|&&n| component_sets[ci].contains(&n)) - .copied() - .collect(); - - // Try to find better u from neighbors - 
for u_prime in neighbors_u { - if u_prime == edge.v { - continue; - } - let d_prime = distance_fn(&data[u_prime], &data[best_v]); - if d_prime < best_d { - best_u = u_prime; - best_d = d_prime; + // (undirected_graph is sorted, so we can iterate directly) + for &u_prime in &undirected_graph[edge.u] { + if node_to_component[u_prime] == ci && u_prime != edge.v { + let d_prime = distance_fn(&data[u_prime], &data[best_v]); + if d_prime < best_d { + best_u = u_prime; + best_d = d_prime; + } } } // Get neighbors of v that are in component cj - let neighbors_v: Vec = undirected_graph[edge.v] - .iter() - .filter(|&&n| component_sets[cj].contains(&n)) - .copied() - .collect(); - - // Try to find better v from neighbors (using updated best_u) - for v_prime in neighbors_v { - if v_prime == edge.u { - continue; - } - let d_prime = distance_fn(&data[best_u], &data[v_prime]); - if d_prime < best_d { - best_v = v_prime; - best_d = d_prime; + for &v_prime in &undirected_graph[edge.v] { + if node_to_component[v_prime] == cj && v_prime != edge.u { + let d_prime = distance_fn(&data[best_u], &data[v_prime]); + if d_prime < best_d { + best_v = v_prime; + best_d = d_prime; + } } } From 12c98c9cf36d58a05bd7de17bcbd19479993206d Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 12:13:00 -0600 Subject: [PATCH 10/29] Refactor AnnGraph to use single flat allocation - Store neighbors as flat Vec instead of Vec> + Vec> - Neighbor struct combines index and distance (reusing existing NeighborEntry) - Access via ann_graph.neighbors(i) returns &[Neighbor] slice - Eliminates n pointer indirections and reduces memory fragmentation --- crates/famst/src/lib.rs | 95 +++++++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 47 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 16f0f65..335ad06 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -70,25 +70,34 @@ impl UnionFind { } /// Approximate Nearest Neighbors graph representation -/// Contains neighbor indices and distances for each point +/// Stored as a flat n×k matrix of Neighbor entries pub struct AnnGraph { - /// neighbors[i] contains the indices of k nearest neighbors of point i - pub neighbors: Vec>, - /// distances[i] contains the distances to k nearest neighbors of point i - pub distances: Vec>, + /// Flat storage: data[i*k..(i+1)*k] contains k neighbors of point i + data: Vec, + /// Number of points + n: usize, + /// Number of neighbors per point + k: usize, } impl AnnGraph { - pub fn new(neighbors: Vec>, distances: Vec>) -> Self { - assert_eq!(neighbors.len(), distances.len()); - AnnGraph { - neighbors, - distances, - } + pub fn new(n: usize, k: usize, data: Vec) -> Self { + assert_eq!(data.len(), n * k); + AnnGraph { data, n, k } } pub fn n(&self) -> usize { - self.neighbors.len() + self.n + } + + pub fn k(&self) -> usize { + self.k + } + + /// Get the neighbors of point i + pub fn neighbors(&self, i: usize) -> &[Neighbor] { + let start = i * self.k; + &self.data[start..start + self.k] } } @@ -218,28 +227,29 @@ where } } -/// A neighbor entry in the k-NN heap (max-heap by distance for easy replacement of farthest) -#[derive(Clone, Copy)] -struct NeighborEntry { - index: usize, - distance: f64, +/// A neighbor entry: node index and distance +/// Used both in the k-NN heap and in the final AnnGraph +#[derive(Debug, Clone, Copy)] +pub struct Neighbor { + pub index: usize, + pub distance: f64, } -impl PartialEq for NeighborEntry { +impl PartialEq for Neighbor { fn eq(&self, other: 
&Self) -> bool { self.distance == other.distance } } -impl Eq for NeighborEntry {} +impl Eq for Neighbor {} -impl PartialOrd for NeighborEntry { +impl PartialOrd for Neighbor { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for NeighborEntry { +impl Ord for Neighbor { fn cmp(&self, other: &Self) -> std::cmp::Ordering { // Max-heap: larger distances have higher priority self.distance @@ -261,7 +271,7 @@ where let k = config.k.min(n - 1); if k == 0 || n <= 1 { - return AnnGraph::new(vec![vec![]; n], vec![vec![]; n]); + return AnnGraph::new(n, 0, vec![]); } // Helper: check if sorted vec contains value @@ -289,7 +299,7 @@ where // Initialize with random neighbors using max-heap for each point // neighbor_lists[i] is kept sorted by index for O(log k) membership tests - let mut heaps: Vec> = Vec::with_capacity(n); + let mut heaps: Vec> = Vec::with_capacity(n); let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; for i in 0..n { @@ -310,13 +320,13 @@ where let actual_t = if t >= i { t + 1 } else { t }; sorted_insert(&mut neighbor_lists[i], actual_t); let d = distance_fn(&data[i], &data[actual_t]); - heap.push(NeighborEntry { + heap.push(Neighbor { index: actual_t, distance: d, }); } else { let d = distance_fn(&data[i], &data[actual_j]); - heap.push(NeighborEntry { + heap.push(Neighbor { index: actual_j, distance: d, }); @@ -392,7 +402,7 @@ where let removed = heaps[i].pop().unwrap(); sorted_remove(&mut neighbor_lists[i], removed.index); - heaps[i].push(NeighborEntry { + heaps[i].push(Neighbor { index: c, distance: d, }); @@ -409,21 +419,16 @@ where } } - // Convert heaps to sorted neighbor lists - let mut neighbors = vec![Vec::with_capacity(k); n]; - let mut distances = vec![Vec::with_capacity(k); n]; + // Convert heaps to flat neighbor array sorted by distance + let mut result_data = Vec::with_capacity(n * k); - for (i, heap) in heaps.into_iter().enumerate() { - let mut entries: Vec = heap.into_vec(); + for heap in heaps { + let mut entries: Vec = heap.into_vec(); entries.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); - - for entry in entries { - neighbors[i].push(entry.index); - distances[i].push(entry.distance); - } + result_data.extend(entries); } - AnnGraph::new(neighbors, distances) + AnnGraph::new(n, k, result_data) } /// Find connected components in the ANN graph using DFS @@ -433,8 +438,9 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { // Build undirected graph from directed ANN graph using sorted vecs let mut graph: Vec> = vec![Vec::new(); n]; - for (i, neighbors) in ann_graph.neighbors.iter().enumerate() { - for &j in neighbors { + for i in 0..n { + for neighbor in ann_graph.neighbors(i) { + let j = neighbor.index; graph[i].push(j); graph[j].push(i); } @@ -588,14 +594,9 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec = Vec::new(); - for (i, (neighbors, distances)) in ann_graph - .neighbors - .iter() - .zip(ann_graph.distances.iter()) - .enumerate() - { - for (&j, &d) in neighbors.iter().zip(distances.iter()) { - edges.push(Edge::new(i, j, d)); + for i in 0..n { + for neighbor in ann_graph.neighbors(i) { + edges.push(Edge::new(i, neighbor.index, neighbor.distance)); } } From 2561f992e3e750a655833b80f1eb8351b857cf69 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 12:20:22 -0600 Subject: [PATCH 11/29] Make internal types private Only expose the public API: famst, famst_with_rng, FamstConfig, FamstResult, Edge. 
UnionFind, AnnGraph, and Neighbor are now private implementation details. --- crates/famst/src/lib.rs | 79 ++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 335ad06..46e437d 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -30,27 +30,27 @@ impl Edge { } /// Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm -pub struct UnionFind { +struct UnionFind { parent: Vec, rank: Vec, } impl UnionFind { - pub fn new(n: usize) -> Self { + fn new(n: usize) -> Self { UnionFind { parent: (0..n).collect(), rank: vec![0; n], } } - pub fn find(&mut self, x: usize) -> usize { + fn find(&mut self, x: usize) -> usize { if self.parent[x] != x { self.parent[x] = self.find(self.parent[x]); // Path compression } self.parent[x] } - pub fn union(&mut self, x: usize, y: usize) -> bool { + fn union(&mut self, x: usize, y: usize) -> bool { let px = self.find(x); let py = self.find(y); if px == py { @@ -69,9 +69,39 @@ impl UnionFind { } } +/// A neighbor entry: node index and distance +#[derive(Debug, Clone, Copy)] +struct Neighbor { + index: usize, + distance: f64, +} + +impl PartialEq for Neighbor { + fn eq(&self, other: &Self) -> bool { + self.distance == other.distance + } +} + +impl Eq for Neighbor {} + +impl PartialOrd for Neighbor { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Neighbor { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Max-heap: larger distances have higher priority + self.distance + .partial_cmp(&other.distance) + .unwrap_or(std::cmp::Ordering::Equal) + } +} + /// Approximate Nearest Neighbors graph representation /// Stored as a flat n×k matrix of Neighbor entries -pub struct AnnGraph { +struct AnnGraph { /// Flat storage: data[i*k..(i+1)*k] contains k neighbors of point i data: Vec, /// Number of points @@ -81,21 +111,21 @@ pub struct AnnGraph { } impl AnnGraph { - pub fn new(n: usize, k: usize, data: Vec) -> Self { + fn new(n: usize, k: usize, data: Vec) -> Self { assert_eq!(data.len(), n * k); AnnGraph { data, n, k } } - pub fn n(&self) -> usize { + fn n(&self) -> usize { self.n } - pub fn k(&self) -> usize { + fn k(&self) -> usize { self.k } /// Get the neighbors of point i - pub fn neighbors(&self, i: usize) -> &[Neighbor] { + fn neighbors(&self, i: usize) -> &[Neighbor] { let start = i * self.k; &self.data[start..start + self.k] } @@ -227,37 +257,6 @@ where } } -/// A neighbor entry: node index and distance -/// Used both in the k-NN heap and in the final AnnGraph -#[derive(Debug, Clone, Copy)] -pub struct Neighbor { - pub index: usize, - pub distance: f64, -} - -impl PartialEq for Neighbor { - fn eq(&self, other: &Self) -> bool { - self.distance == other.distance - } -} - -impl Eq for Neighbor {} - -impl PartialOrd for Neighbor { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Neighbor { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - // Max-heap: larger distances have higher priority - self.distance - .partial_cmp(&other.distance) - .unwrap_or(std::cmp::Ordering::Equal) - } -} - /// NN-Descent algorithm for approximate k-NN graph construction /// /// Based on: "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures" From 4ec00638c315fa16bd26e783c17268281efdbbd4 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Sat, 17 Jan 2026 13:19:35 -0600 Subject: [PATCH 12/29] Switch to 32-bit types for 
memory efficiency - Use NodeId (u32) type alias for node indices - Use f32 for distances instead of f64 - Add assertion that data.len() <= 2^32 with documented panic - This halves memory usage for internal data structures For n=1 billion, k=20, this saves ~120 GB of memory in the AnnGraph alone. --- crates/famst/src/lib.rs | 175 +++++++++++++++++++++------------------- 1 file changed, 94 insertions(+), 81 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 46e437d..a29a014 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -20,11 +20,11 @@ use std::collections::BinaryHeap; pub struct Edge { pub u: usize, pub v: usize, - pub distance: f64, + pub distance: f32, } impl Edge { - pub fn new(u: usize, v: usize, distance: f64) -> Self { + fn new(u: usize, v: usize, distance: f32) -> Self { Edge { u, v, distance } } } @@ -69,11 +69,14 @@ impl UnionFind { } } -/// A neighbor entry: node index and distance +/// Node index type (32-bit for memory efficiency, limits graphs to 2^32 nodes) +type NodeId = u32; + +/// A neighbor entry: node index and distance (32-bit for memory efficiency) #[derive(Debug, Clone, Copy)] struct Neighbor { - index: usize, - distance: f64, + index: NodeId, + distance: f32, } impl PartialEq for Neighbor { @@ -120,10 +123,6 @@ impl AnnGraph { self.n } - fn k(&self) -> usize { - self.k - } - /// Get the neighbors of point i fn neighbors(&self, i: usize) -> &[Neighbor] { let start = i * self.k; @@ -162,14 +161,14 @@ pub struct FamstResult { /// MST edges pub edges: Vec, /// Total weight of the MST - pub total_weight: f64, + pub total_weight: f32, } /// Main FAMST algorithm implementation /// /// Generic over: /// - `T`: The data type stored at each point -/// - `D`: Distance function `Fn(&T, &T) -> f64` +/// - `D`: Distance function `Fn(&T, &T) -> f32` /// /// # Arguments /// * `data` - Slice of data points @@ -178,14 +177,20 @@ pub struct FamstResult { /// /// # Returns /// The approximate MST as a list of edges +/// +/// # Panics +/// Panics if `data.len() >= 2^32` (more than ~4 billion points). pub fn famst(data: &[T], distance_fn: D, config: &FamstConfig) -> FamstResult where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, { famst_with_rng(data, distance_fn, config, &mut rand::thread_rng()) } /// FAMST with custom RNG. (We use a seeded RNG in tests for reproducibility.) +/// +/// # Panics +/// Panics if `data.len() >= 2^32` (more than ~4 billion points). 
pub fn famst_with_rng( data: &[T], distance_fn: D, @@ -193,10 +198,15 @@ pub fn famst_with_rng( rng: &mut R, ) -> FamstResult where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, R: Rng, { let n = data.len(); + assert!( + n <= NodeId::MAX as usize, + "famst: data length {n} exceeds maximum supported size of 2^32" + ); + if n <= 1 { return FamstResult { edges: vec![], @@ -263,7 +273,7 @@ where /// by Wei Dong, Charikar Moses, and Kai Li (2011) fn nn_descent(data: &[T], distance_fn: &D, config: &FamstConfig, rng: &mut R) -> AnnGraph where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, R: Rng, { let n = data.len(); @@ -274,12 +284,12 @@ where } // Helper: check if sorted vec contains value - fn sorted_contains(v: &[usize], x: usize) -> bool { + fn sorted_contains(v: &[NodeId], x: NodeId) -> bool { v.binary_search(&x).is_ok() } // Helper: insert into sorted vec, returns true if inserted (was not present) - fn sorted_insert(v: &mut Vec, x: usize) -> bool { + fn sorted_insert(v: &mut Vec, x: NodeId) -> bool { match v.binary_search(&x) { Ok(_) => false, Err(pos) => { @@ -290,7 +300,7 @@ where } // Helper: remove from sorted vec - fn sorted_remove(v: &mut Vec, x: usize) { + fn sorted_remove(v: &mut Vec, x: NodeId) { if let Ok(pos) = v.binary_search(&x) { v.remove(pos); } @@ -299,7 +309,7 @@ where // Initialize with random neighbors using max-heap for each point // neighbor_lists[i] is kept sorted by index for O(log k) membership tests let mut heaps: Vec> = Vec::with_capacity(n); - let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; + let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; for i in 0..n { let mut heap = BinaryHeap::with_capacity(k); @@ -312,19 +322,19 @@ where for t in range_start..effective_n { let j = rng.gen_range(0..=t); // Map j to actual index, skipping i - let actual_j = if j >= i { j + 1 } else { j }; + let actual_j = (if j >= i { j + 1 } else { j }) as NodeId; if !sorted_insert(&mut neighbor_lists[i], actual_j) { // j was already selected, so add t instead - let actual_t = if t >= i { t + 1 } else { t }; + let actual_t = (if t >= i { t + 1 } else { t }) as NodeId; sorted_insert(&mut neighbor_lists[i], actual_t); - let d = distance_fn(&data[i], &data[actual_t]); + let d = distance_fn(&data[i], &data[actual_t as usize]); heap.push(Neighbor { index: actual_t, distance: d, }); } else { - let d = distance_fn(&data[i], &data[actual_j]); + let d = distance_fn(&data[i], &data[actual_j as usize]); heap.push(Neighbor { index: actual_j, distance: d, @@ -336,11 +346,11 @@ where // Build reverse neighbor lists (who has me as a neighbor) // Returns sorted vecs for each point - let build_reverse = |neighbor_lists: &[Vec]| -> Vec> { - let mut reverse: Vec> = vec![Vec::new(); n]; + let build_reverse = |neighbor_lists: &[Vec]| -> Vec> { + let mut reverse: Vec> = vec![Vec::new(); n]; for (i, neighbors) in neighbor_lists.iter().enumerate() { for &j in neighbors { - reverse[j].push(i); + reverse[j as usize].push(i as NodeId); } } // Sort each reverse list (they're built in order of i, so already sorted) @@ -355,7 +365,7 @@ where // For each point, explore neighbors of neighbors for i in 0..n { // Collect candidates: neighbors and reverse neighbors - let mut candidates: Vec = Vec::new(); + let mut candidates: Vec = Vec::new(); // Sample from forward neighbors let mut sampled_forward = neighbor_lists[i].clone(); @@ -372,15 +382,16 @@ where sampled_reverse.truncate(sample_size); // Neighbors of neighbors + let i_id = i as NodeId; for &neighbor in 
sampled_forward.iter().chain(sampled_reverse.iter()) { - for &nn in &neighbor_lists[neighbor] { - if nn != i && !sorted_contains(&neighbor_lists[i], nn) { + for &nn in &neighbor_lists[neighbor as usize] { + if nn != i_id && !sorted_contains(&neighbor_lists[i], nn) { candidates.push(nn); } } // Also check reverse neighbors of neighbors - for &rn in &reverse_neighbors[neighbor] { - if rn != i && !sorted_contains(&neighbor_lists[i], rn) { + for &rn in &reverse_neighbors[neighbor as usize] { + if rn != i_id && !sorted_contains(&neighbor_lists[i], rn) { candidates.push(rn); } } @@ -392,7 +403,7 @@ where // Try to improve neighbors for c in candidates { - let d = distance_fn(&data[i], &data[c]); + let d = distance_fn(&data[i], &data[c as usize]); // Check if this is better than the worst current neighbor if let Some(worst) = heaps[i].peek() { @@ -432,16 +443,16 @@ where /// Find connected components in the ANN graph using DFS /// Returns the undirected graph adjacency list (sorted vecs) and component assignments -fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { +fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { let n = ann_graph.n(); // Build undirected graph from directed ANN graph using sorted vecs - let mut graph: Vec> = vec![Vec::new(); n]; + let mut graph: Vec> = vec![Vec::new(); n]; for i in 0..n { for neighbor in ann_graph.neighbors(i) { let j = neighbor.index; graph[i].push(j); - graph[j].push(i); + graph[j as usize].push(i as NodeId); } } // Sort and deduplicate each adjacency list @@ -452,7 +463,7 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { // DFS to find components let mut visited = vec![false; n]; - let mut components: Vec> = Vec::new(); + let mut components: Vec> = Vec::new(); for start in 0..n { if visited[start] { @@ -460,17 +471,17 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { } let mut component = Vec::new(); - let mut stack = vec![start]; + let mut stack = vec![start as NodeId]; while let Some(u) = stack.pop() { - if visited[u] { + if visited[u as usize] { continue; } - visited[u] = true; + visited[u as usize] = true; component.push(u); - for &v in &graph[u] { - if !visited[v] { + for &v in &graph[u as usize] { + if !visited[v as usize] { stack.push(v); } } @@ -485,13 +496,13 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { /// Add random edges between components (Algorithm 3 in the paper) fn add_random_edges( data: &[T], - components: &[Vec], + components: &[Vec], lambda: usize, distance_fn: &D, rng: &mut R, ) -> (Vec, Vec<(usize, usize)>) where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, R: Rng, { let t = components.len(); @@ -506,8 +517,8 @@ where // Generate λ² candidate edges for _ in 0..lambda_sq { - let u = *components[i].choose(rng).unwrap(); - let v = *components[j].choose(rng).unwrap(); + let u = *components[i].choose(rng).unwrap() as usize; + let v = *components[j].choose(rng).unwrap() as usize; let d = distance_fn(&data[u], &data[v]); candidates.push(Edge::new(u, v, d)); } @@ -528,14 +539,14 @@ where /// Refine inter-component edges (Algorithm 4 in the paper) fn refine_edges( data: &[T], - undirected_graph: &[Vec], - components: &[Vec], + undirected_graph: &[Vec], + components: &[Vec], edges: &[Edge], edge_components: &[(usize, usize)], distance_fn: &D, ) -> (Vec, usize) where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, { let n = data.len(); @@ -543,7 +554,7 @@ where let mut node_to_component: Vec = vec![0; n]; for (comp_idx, component) in components.iter().enumerate() { for &node in component 
{ - node_to_component[node] = comp_idx; + node_to_component[node as usize] = comp_idx; } } @@ -558,6 +569,7 @@ where // Get neighbors of u that are in component ci // (undirected_graph is sorted, so we can iterate directly) for &u_prime in &undirected_graph[edge.u] { + let u_prime = u_prime as usize; if node_to_component[u_prime] == ci && u_prime != edge.v { let d_prime = distance_fn(&data[u_prime], &data[best_v]); if d_prime < best_d { @@ -569,6 +581,7 @@ where // Get neighbors of v that are in component cj for &v_prime in &undirected_graph[edge.v] { + let v_prime = v_prime as usize; if node_to_component[v_prime] == cj && v_prime != edge.u { let d_prime = distance_fn(&data[best_u], &data[v_prime]); if d_prime < best_d { @@ -595,7 +608,7 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec f64 { + /// Manhattan distance for slices of f32 + fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 { a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum() } - /// Euclidean distance for slices of f64 - pub fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { + /// Euclidean distance for slices of f32 + fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 { a.iter() .zip(b.iter()) .map(|(x, y)| (x - y).powi(2)) - .sum::() + .sum::() .sqrt() } #[test] fn test_empty_input() { - let points: Vec> = vec![]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let points: Vec> = vec![]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let result = famst(&points, distance, &FamstConfig::default()); assert_eq!(result.edges.len(), 0); assert_eq!(result.total_weight, 0.0); @@ -657,8 +670,8 @@ mod tests { #[test] fn test_single_point() { - let points: Vec> = vec![vec![1.0, 2.0, 3.0]]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let points: Vec> = vec![vec![1.0, 2.0, 3.0]]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let result = famst(&points, distance, &FamstConfig::default()); assert_eq!(result.edges.len(), 0); assert_eq!(result.total_weight, 0.0); @@ -667,8 +680,8 @@ mod tests { #[test] fn test_k_greater_than_n() { // 3 points but k=20 (default), so k >= n - let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]]; + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig::default(); // k=20 > n=3 let result = famst(&points, distance, &config); assert_eq!(result.edges.len(), 2); // MST has n-1 edges @@ -687,13 +700,13 @@ mod tests { #[test] fn test_simple_mst() { // Simple 2D points forming a triangle - let points: Vec> = vec![ + let points: Vec> = vec![ vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.866], // Equilateral triangle ]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig { k: 2, ..Default::default() @@ -708,9 +721,9 @@ mod tests { #[test] fn test_line_points() { // Points on a line - let points: Vec> = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]]; + let points: Vec> = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0], vec![4.0]]; - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig { k: 2, ..Default::default() @@ -727,7 +740,7 @@ mod tests { #[test] fn test_disconnected_components() { // Two clusters far apart - let points: Vec> = 
vec![ + let points: Vec> = vec![ vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.5], @@ -737,7 +750,7 @@ mod tests { ]; // k=1 will likely create disconnected components - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig { k: 1, lambda: 3, @@ -754,9 +767,9 @@ mod tests { #[test] fn test_custom_distance() { // Test with Manhattan distance - let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]]; + let points: Vec> = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]]; - let distance = |a: &Vec, b: &Vec| manhattan_distance(a, b); + let distance = |a: &Vec, b: &Vec| manhattan_distance(a, b); let config = FamstConfig { k: 2, ..Default::default() @@ -775,12 +788,12 @@ mod tests { // Test with a custom struct #[derive(Clone)] struct Point3D { - x: f64, - y: f64, - z: f64, + x: f32, + y: f32, + z: f32, } - fn point_distance(a: &Point3D, b: &Point3D) -> f64 { + fn point_distance(a: &Point3D, b: &Point3D) -> f32 { ((a.x - b.x).powi(2) + (a.y - b.y).powi(2) + (a.z - b.z).powi(2)).sqrt() } @@ -836,7 +849,7 @@ mod tests { ]; let points_per_cluster = 20; - let mut points: Vec> = Vec::new(); + let mut points: Vec> = Vec::new(); for center in &cluster_centers { for _ in 0..points_per_cluster { @@ -849,7 +862,7 @@ mod tests { } let n = points.len(); - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); // Use small k to create disconnected components // With k=3 and 20 points per cluster spread over 5 clusters, @@ -899,9 +912,9 @@ mod tests { } /// Compute exact MST using Kruskal's algorithm on complete graph - fn exact_mst_weight(data: &[T], distance_fn: D) -> f64 + fn exact_mst_weight(data: &[T], distance_fn: D) -> f32 where - D: Fn(&T, &T) -> f64, + D: Fn(&T, &T) -> f32, { let n = data.len(); if n <= 1 { @@ -909,7 +922,7 @@ mod tests { } // Build all edges - let mut edges: Vec<(usize, usize, f64)> = Vec::with_capacity(n * (n - 1) / 2); + let mut edges: Vec<(usize, usize, f32)> = Vec::with_capacity(n * (n - 1) / 2); for i in 0..n { for j in (i + 1)..n { let d = distance_fn(&data[i], &data[j]); @@ -950,12 +963,12 @@ mod tests { let mut rng = StdRng::seed_from_u64(12345); let dist = Uniform::new(0.0, 1000.0); - let points: Vec> = (0..N) + let points: Vec> = (0..N) .map(|_| (0..DIM).map(|_| dist.sample(&mut rng)).collect()) .collect(); println!("Running FAMST with NN-Descent..."); - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); let config = FamstConfig { k: 20, lambda: 5, @@ -985,11 +998,11 @@ mod tests { let mut rng = StdRng::seed_from_u64(99999); let dist = Uniform::new(0.0, 100.0); - let points: Vec> = (0..N) + let points: Vec> = (0..N) .map(|_| (0..DIM).map(|_| dist.sample(&mut rng)).collect()) .collect(); - let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); + let distance = |a: &Vec, b: &Vec| euclidean_distance(a, b); // Compute exact MST let exact_weight = exact_mst_weight(&points, distance); From 2940252c6ab15aa0ae24daf4a7ce228feae36cd5 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 07:46:47 +0100 Subject: [PATCH 13/29] extract nn_descent into separate file --- crates/famst/src/lib.rs | 188 ++------------------------------- crates/famst/src/nn_descent.rs | 188 +++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 180 deletions(-) create mode 100644 crates/famst/src/nn_descent.rs diff --git 
a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index a29a014..9f15cd5 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -11,9 +11,11 @@ //! //! Generic over data type `T` and distance function. +mod nn_descent; + +use nn_descent::nn_descent; use rand::seq::SliceRandom; use rand::Rng; -use std::collections::BinaryHeap; /// An edge in the MST, represented as (node_a, node_b, distance) #[derive(Debug, Clone)] @@ -70,13 +72,13 @@ impl UnionFind { } /// Node index type (32-bit for memory efficiency, limits graphs to 2^32 nodes) -type NodeId = u32; +pub(crate) type NodeId = u32; /// A neighbor entry: node index and distance (32-bit for memory efficiency) #[derive(Debug, Clone, Copy)] -struct Neighbor { - index: NodeId, - distance: f32, +pub(crate) struct Neighbor { + pub(crate) index: NodeId, + pub(crate) distance: f32, } impl PartialEq for Neighbor { @@ -104,7 +106,7 @@ impl Ord for Neighbor { /// Approximate Nearest Neighbors graph representation /// Stored as a flat n×k matrix of Neighbor entries -struct AnnGraph { +pub(crate) struct AnnGraph { /// Flat storage: data[i*k..(i+1)*k] contains k neighbors of point i data: Vec, /// Number of points @@ -267,180 +269,6 @@ where } } -/// NN-Descent algorithm for approximate k-NN graph construction -/// -/// Based on: "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures" -/// by Wei Dong, Charikar Moses, and Kai Li (2011) -fn nn_descent(data: &[T], distance_fn: &D, config: &FamstConfig, rng: &mut R) -> AnnGraph -where - D: Fn(&T, &T) -> f32, - R: Rng, -{ - let n = data.len(); - let k = config.k.min(n - 1); - - if k == 0 || n <= 1 { - return AnnGraph::new(n, 0, vec![]); - } - - // Helper: check if sorted vec contains value - fn sorted_contains(v: &[NodeId], x: NodeId) -> bool { - v.binary_search(&x).is_ok() - } - - // Helper: insert into sorted vec, returns true if inserted (was not present) - fn sorted_insert(v: &mut Vec, x: NodeId) -> bool { - match v.binary_search(&x) { - Ok(_) => false, - Err(pos) => { - v.insert(pos, x); - true - } - } - } - - // Helper: remove from sorted vec - fn sorted_remove(v: &mut Vec, x: NodeId) { - if let Ok(pos) = v.binary_search(&x) { - v.remove(pos); - } - } - - // Initialize with random neighbors using max-heap for each point - // neighbor_lists[i] is kept sorted by index for O(log k) membership tests - let mut heaps: Vec> = Vec::with_capacity(n); - let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; - - for i in 0..n { - let mut heap = BinaryHeap::with_capacity(k); - - // Sample k random neighbors using Floyd's algorithm - guaranteed O(k) - // https://fermatslibrary.com/s/a-sample-of-brilliance - // This selects k distinct elements from 0..n, excluding i - let effective_n = n - 1; // exclude self - let range_start = effective_n.saturating_sub(k); - for t in range_start..effective_n { - let j = rng.gen_range(0..=t); - // Map j to actual index, skipping i - let actual_j = (if j >= i { j + 1 } else { j }) as NodeId; - - if !sorted_insert(&mut neighbor_lists[i], actual_j) { - // j was already selected, so add t instead - let actual_t = (if t >= i { t + 1 } else { t }) as NodeId; - sorted_insert(&mut neighbor_lists[i], actual_t); - let d = distance_fn(&data[i], &data[actual_t as usize]); - heap.push(Neighbor { - index: actual_t, - distance: d, - }); - } else { - let d = distance_fn(&data[i], &data[actual_j as usize]); - heap.push(Neighbor { - index: actual_j, - distance: d, - }); - } - } - heaps.push(heap); - } - - // Build reverse neighbor lists (who has me 
as a neighbor) - // Returns sorted vecs for each point - let build_reverse = |neighbor_lists: &[Vec]| -> Vec> { - let mut reverse: Vec> = vec![Vec::new(); n]; - for (i, neighbors) in neighbor_lists.iter().enumerate() { - for &j in neighbors { - reverse[j as usize].push(i as NodeId); - } - } - // Sort each reverse list (they're built in order of i, so already sorted) - reverse - }; - - // NN-Descent iterations - for _ in 0..config.nn_descent_iterations { - let mut updates = 0; - let reverse_neighbors = build_reverse(&neighbor_lists); - - // For each point, explore neighbors of neighbors - for i in 0..n { - // Collect candidates: neighbors and reverse neighbors - let mut candidates: Vec = Vec::new(); - - // Sample from forward neighbors - let mut sampled_forward = neighbor_lists[i].clone(); - let sample_size = - ((sampled_forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); - sampled_forward.shuffle(rng); - sampled_forward.truncate(sample_size); - - // Sample from reverse neighbors - let mut sampled_reverse = reverse_neighbors[i].clone(); - let sample_size = - ((sampled_reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize).max(1); - sampled_reverse.shuffle(rng); - sampled_reverse.truncate(sample_size); - - // Neighbors of neighbors - let i_id = i as NodeId; - for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) { - for &nn in &neighbor_lists[neighbor as usize] { - if nn != i_id && !sorted_contains(&neighbor_lists[i], nn) { - candidates.push(nn); - } - } - // Also check reverse neighbors of neighbors - for &rn in &reverse_neighbors[neighbor as usize] { - if rn != i_id && !sorted_contains(&neighbor_lists[i], rn) { - candidates.push(rn); - } - } - } - - // Deduplicate candidates - candidates.sort_unstable(); - candidates.dedup(); - - // Try to improve neighbors - for c in candidates { - let d = distance_fn(&data[i], &data[c as usize]); - - // Check if this is better than the worst current neighbor - if let Some(worst) = heaps[i].peek() { - if d < worst.distance { - // Remove worst and add new neighbor - let removed = heaps[i].pop().unwrap(); - sorted_remove(&mut neighbor_lists[i], removed.index); - - heaps[i].push(Neighbor { - index: c, - distance: d, - }); - sorted_insert(&mut neighbor_lists[i], c); - updates += 1; - } - } - } - } - - // Early termination if no updates - if updates == 0 { - break; - } - } - - // Convert heaps to flat neighbor array sorted by distance - let mut result_data = Vec::with_capacity(n * k); - - for heap in heaps { - let mut entries: Vec = heap.into_vec(); - entries.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); - result_data.extend(entries); - } - - AnnGraph::new(n, k, result_data) -} - /// Find connected components in the ANN graph using DFS /// Returns the undirected graph adjacency list (sorted vecs) and component assignments fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs new file mode 100644 index 0000000..e797f7d --- /dev/null +++ b/crates/famst/src/nn_descent.rs @@ -0,0 +1,188 @@ +//! NN-Descent algorithm for approximate k-NN graph construction +//! +//! Based on: "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures" +//! 
by Wei Dong, Charikar Moses, and Kai Li (2011) + +use rand::seq::SliceRandom; +use rand::Rng; +use std::collections::BinaryHeap; + +use crate::{AnnGraph, FamstConfig, Neighbor, NodeId}; + +/// NN-Descent algorithm for approximate k-NN graph construction +/// +/// Based on: "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures" +/// by Wei Dong, Charikar Moses, and Kai Li (2011) +pub(crate) fn nn_descent( + data: &[T], + distance_fn: &D, + config: &FamstConfig, + rng: &mut R, +) -> AnnGraph +where + D: Fn(&T, &T) -> f32, + R: Rng, +{ + let n = data.len(); + let k = config.k.min(n - 1); + + if k == 0 || n <= 1 { + return AnnGraph::new(n, 0, vec![]); + } + + // Helper: check if sorted vec contains value + fn sorted_contains(v: &[NodeId], x: NodeId) -> bool { + v.binary_search(&x).is_ok() + } + + // Helper: insert into sorted vec, returns true if inserted (was not present) + fn sorted_insert(v: &mut Vec, x: NodeId) -> bool { + match v.binary_search(&x) { + Ok(_) => false, + Err(pos) => { + v.insert(pos, x); + true + } + } + } + + // Helper: remove from sorted vec + fn sorted_remove(v: &mut Vec, x: NodeId) { + if let Ok(pos) = v.binary_search(&x) { + v.remove(pos); + } + } + + // Initialize with random neighbors using max-heap for each point + // neighbor_lists[i] is kept sorted by index for O(log k) membership tests + let mut heaps: Vec> = Vec::with_capacity(n); + let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; + + for i in 0..n { + let mut heap = BinaryHeap::with_capacity(k); + + // Sample k random neighbors using Floyd's algorithm - guaranteed O(k) + // https://fermatslibrary.com/s/a-sample-of-brilliance + // This selects k distinct elements from 0..n, excluding i + let effective_n = n - 1; // exclude self + let range_start = effective_n.saturating_sub(k); + for t in range_start..effective_n { + let j = rng.gen_range(0..=t); + // Map j to actual index, skipping i + let actual_j = (if j >= i { j + 1 } else { j }) as NodeId; + + if !sorted_insert(&mut neighbor_lists[i], actual_j) { + // j was already selected, so add t instead + let actual_t = (if t >= i { t + 1 } else { t }) as NodeId; + sorted_insert(&mut neighbor_lists[i], actual_t); + let d = distance_fn(&data[i], &data[actual_t as usize]); + heap.push(Neighbor { + index: actual_t, + distance: d, + }); + } else { + let d = distance_fn(&data[i], &data[actual_j as usize]); + heap.push(Neighbor { + index: actual_j, + distance: d, + }); + } + } + heaps.push(heap); + } + + // Build reverse neighbor lists (who has me as a neighbor) + // Returns sorted vecs for each point + let build_reverse = |neighbor_lists: &[Vec]| -> Vec> { + let mut reverse: Vec> = vec![Vec::new(); n]; + for (i, neighbors) in neighbor_lists.iter().enumerate() { + for &j in neighbors { + reverse[j as usize].push(i as NodeId); + } + } + // Sort each reverse list (they're built in order of i, so already sorted) + reverse + }; + + // NN-Descent iterations + for _ in 0..config.nn_descent_iterations { + let mut updates = 0; + let reverse_neighbors = build_reverse(&neighbor_lists); + + // For each point, explore neighbors of neighbors + for i in 0..n { + // Collect candidates: neighbors and reverse neighbors + let mut candidates: Vec = Vec::new(); + + // Sample from forward neighbors + let mut sampled_forward = neighbor_lists[i].clone(); + let sample_size = + ((sampled_forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize) + .max(1); + sampled_forward.shuffle(rng); + sampled_forward.truncate(sample_size); + + // Sample from 
reverse neighbors + let mut sampled_reverse = reverse_neighbors[i].clone(); + let sample_size = + ((sampled_reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize) + .max(1); + sampled_reverse.shuffle(rng); + sampled_reverse.truncate(sample_size); + + // Neighbors of neighbors + let i_id = i as NodeId; + for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) { + for &nn in &neighbor_lists[neighbor as usize] { + if nn != i_id && !sorted_contains(&neighbor_lists[i], nn) { + candidates.push(nn); + } + } + // Also check reverse neighbors of neighbors + for &rn in &reverse_neighbors[neighbor as usize] { + if rn != i_id && !sorted_contains(&neighbor_lists[i], rn) { + candidates.push(rn); + } + } + } + + // Deduplicate candidates + candidates.sort_unstable(); + candidates.dedup(); + + // Try to improve neighbors + for c in candidates { + let d = distance_fn(&data[i], &data[c as usize]); + + // Check if this is better than the worst current neighbor + if let Some(worst) = heaps[i].peek() { + if d < worst.distance { + // Remove worst and add new neighbor + let removed = heaps[i].pop().unwrap(); + sorted_remove(&mut neighbor_lists[i], removed.index); + + heaps[i].push(Neighbor { index: c, distance: d }); + sorted_insert(&mut neighbor_lists[i], c); + updates += 1; + } + } + } + } + + // Early termination if no updates + if updates == 0 { + break; + } + } + + // Convert heaps to flat neighbor array sorted by distance + let mut result_data = Vec::with_capacity(n * k); + + for heap in heaps { + let mut entries: Vec = heap.into_vec(); + entries.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + result_data.extend(entries); + } + + AnnGraph::new(n, k, result_data) +} From 2d42e40c98b52ed04fc9f38153c9af4014861a0a Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 07:50:28 +0100 Subject: [PATCH 14/29] refactor a bit more --- crates/famst/src/lib.rs | 42 +------------------ crates/famst/src/nn_descent.rs | 74 +++++++++++++++++----------------- crates/famst/src/union_find.rs | 41 +++++++++++++++++++ 3 files changed, 80 insertions(+), 77 deletions(-) create mode 100644 crates/famst/src/union_find.rs diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 9f15cd5..14bcee6 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -12,10 +12,12 @@ //! Generic over data type `T` and distance function. 
mod nn_descent; +mod union_find; use nn_descent::nn_descent; use rand::seq::SliceRandom; use rand::Rng; +use union_find::UnionFind; /// An edge in the MST, represented as (node_a, node_b, distance) #[derive(Debug, Clone)] @@ -31,46 +33,6 @@ impl Edge { } } -/// Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm -struct UnionFind { - parent: Vec, - rank: Vec, -} - -impl UnionFind { - fn new(n: usize) -> Self { - UnionFind { - parent: (0..n).collect(), - rank: vec![0; n], - } - } - - fn find(&mut self, x: usize) -> usize { - if self.parent[x] != x { - self.parent[x] = self.find(self.parent[x]); // Path compression - } - self.parent[x] - } - - fn union(&mut self, x: usize, y: usize) -> bool { - let px = self.find(x); - let py = self.find(y); - if px == py { - return false; - } - // Union by rank - match self.rank[px].cmp(&self.rank[py]) { - std::cmp::Ordering::Less => self.parent[px] = py, - std::cmp::Ordering::Greater => self.parent[py] = px, - std::cmp::Ordering::Equal => { - self.parent[py] = px; - self.rank[px] += 1; - } - } - true - } -} - /// Node index type (32-bit for memory efficiency, limits graphs to 2^32 nodes) pub(crate) type NodeId = u32; diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index e797f7d..e72f827 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -9,6 +9,42 @@ use std::collections::BinaryHeap; use crate::{AnnGraph, FamstConfig, Neighbor, NodeId}; +/// Check if sorted vec contains value +fn sorted_contains(v: &[NodeId], x: NodeId) -> bool { + v.binary_search(&x).is_ok() +} + +/// Insert into sorted vec, returns true if inserted (was not present) +fn sorted_insert(v: &mut Vec, x: NodeId) -> bool { + match v.binary_search(&x) { + Ok(_) => false, + Err(pos) => { + v.insert(pos, x); + true + } + } +} + +/// Remove from sorted vec +fn sorted_remove(v: &mut Vec, x: NodeId) { + if let Ok(pos) = v.binary_search(&x) { + v.remove(pos); + } +} + +/// Build reverse neighbor lists (who has me as a neighbor) +/// Returns sorted vecs for each point +fn build_reverse(neighbor_lists: &[Vec], n: usize) -> Vec> { + let mut reverse: Vec> = vec![Vec::new(); n]; + for (i, neighbors) in neighbor_lists.iter().enumerate() { + for &j in neighbors { + reverse[j as usize].push(i as NodeId); + } + } + // Each reverse list is built in order of i, so already sorted + reverse +} + /// NN-Descent algorithm for approximate k-NN graph construction /// /// Based on: "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures" @@ -30,29 +66,6 @@ where return AnnGraph::new(n, 0, vec![]); } - // Helper: check if sorted vec contains value - fn sorted_contains(v: &[NodeId], x: NodeId) -> bool { - v.binary_search(&x).is_ok() - } - - // Helper: insert into sorted vec, returns true if inserted (was not present) - fn sorted_insert(v: &mut Vec, x: NodeId) -> bool { - match v.binary_search(&x) { - Ok(_) => false, - Err(pos) => { - v.insert(pos, x); - true - } - } - } - - // Helper: remove from sorted vec - fn sorted_remove(v: &mut Vec, x: NodeId) { - if let Ok(pos) = v.binary_search(&x) { - v.remove(pos); - } - } - // Initialize with random neighbors using max-heap for each point // neighbor_lists[i] is kept sorted by index for O(log k) membership tests let mut heaps: Vec> = Vec::with_capacity(n); @@ -91,23 +104,10 @@ where heaps.push(heap); } - // Build reverse neighbor lists (who has me as a neighbor) - // Returns sorted vecs for each point - let build_reverse = |neighbor_lists: &[Vec]| -> Vec> { - let mut 
reverse: Vec> = vec![Vec::new(); n]; - for (i, neighbors) in neighbor_lists.iter().enumerate() { - for &j in neighbors { - reverse[j as usize].push(i as NodeId); - } - } - // Sort each reverse list (they're built in order of i, so already sorted) - reverse - }; - // NN-Descent iterations for _ in 0..config.nn_descent_iterations { let mut updates = 0; - let reverse_neighbors = build_reverse(&neighbor_lists); + let reverse_neighbors = build_reverse(&neighbor_lists, n); // For each point, explore neighbors of neighbors for i in 0..n { diff --git a/crates/famst/src/union_find.rs b/crates/famst/src/union_find.rs new file mode 100644 index 0000000..de69ec0 --- /dev/null +++ b/crates/famst/src/union_find.rs @@ -0,0 +1,41 @@ +//! Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm + +/// Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm +pub(crate) struct UnionFind { + parent: Vec, + rank: Vec, +} + +impl UnionFind { + pub(crate) fn new(n: usize) -> Self { + UnionFind { + parent: (0..n).collect(), + rank: vec![0; n], + } + } + + pub(crate) fn find(&mut self, x: usize) -> usize { + if self.parent[x] != x { + self.parent[x] = self.find(self.parent[x]); // Path compression + } + self.parent[x] + } + + pub(crate) fn union(&mut self, x: usize, y: usize) -> bool { + let px = self.find(x); + let py = self.find(y); + if px == py { + return false; + } + // Union by rank + match self.rank[px].cmp(&self.rank[py]) { + std::cmp::Ordering::Less => self.parent[px] = py, + std::cmp::Ordering::Greater => self.parent[py] = px, + std::cmp::Ordering::Equal => { + self.parent[py] = px; + self.rank[px] += 1; + } + } + true + } +} From 459062706779a074c232977c14229ff6f1ce31c4 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 07:56:30 +0100 Subject: [PATCH 15/29] use NodeId type everywhere --- crates/famst/src/lib.rs | 36 +++++++++++++++++----------------- crates/famst/src/union_find.rs | 28 ++++++++++++++------------ 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 14bcee6..4196eea 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -19,23 +19,23 @@ use rand::seq::SliceRandom; use rand::Rng; use union_find::UnionFind; +/// Node index type (32-bit for memory efficiency, limits graphs to 2^32 nodes) +pub type NodeId = u32; + /// An edge in the MST, represented as (node_a, node_b, distance) #[derive(Debug, Clone)] pub struct Edge { - pub u: usize, - pub v: usize, + pub u: NodeId, + pub v: NodeId, pub distance: f32, } impl Edge { - fn new(u: usize, v: usize, distance: f32) -> Self { + fn new(u: NodeId, v: NodeId, distance: f32) -> Self { Edge { u, v, distance } } } -/// Node index type (32-bit for memory efficiency, limits graphs to 2^32 nodes) -pub(crate) type NodeId = u32; - /// A neighbor entry: node index and distance (32-bit for memory efficiency) #[derive(Debug, Clone, Copy)] pub(crate) struct Neighbor { @@ -307,9 +307,9 @@ where // Generate λ² candidate edges for _ in 0..lambda_sq { - let u = *components[i].choose(rng).unwrap() as usize; - let v = *components[j].choose(rng).unwrap() as usize; - let d = distance_fn(&data[u], &data[v]); + let u = *components[i].choose(rng).unwrap(); + let v = *components[j].choose(rng).unwrap(); + let d = distance_fn(&data[u as usize], &data[v as usize]); candidates.push(Edge::new(u, v, d)); } @@ -358,10 +358,10 @@ where // Get neighbors of u that are in component ci // (undirected_graph is sorted, so we can iterate directly) - for 
&u_prime in &undirected_graph[edge.u] { - let u_prime = u_prime as usize; - if node_to_component[u_prime] == ci && u_prime != edge.v { - let d_prime = distance_fn(&data[u_prime], &data[best_v]); + for &u_prime in &undirected_graph[edge.u as usize] { + let u_prime_usize = u_prime as usize; + if node_to_component[u_prime_usize] == ci && u_prime != edge.v { + let d_prime = distance_fn(&data[u_prime_usize], &data[best_v as usize]); if d_prime < best_d { best_u = u_prime; best_d = d_prime; @@ -370,10 +370,10 @@ where } // Get neighbors of v that are in component cj - for &v_prime in &undirected_graph[edge.v] { - let v_prime = v_prime as usize; - if node_to_component[v_prime] == cj && v_prime != edge.u { - let d_prime = distance_fn(&data[best_u], &data[v_prime]); + for &v_prime in &undirected_graph[edge.v as usize] { + let v_prime_usize = v_prime as usize; + if node_to_component[v_prime_usize] == cj && v_prime != edge.u { + let d_prime = distance_fn(&data[best_u as usize], &data[v_prime_usize]); if d_prime < best_d { best_v = v_prime; best_d = d_prime; @@ -398,7 +398,7 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec, - rank: Vec, + parent: Vec, + rank: Vec, } impl UnionFind { pub(crate) fn new(n: usize) -> Self { UnionFind { - parent: (0..n).collect(), + parent: (0..n as NodeId).collect(), rank: vec![0; n], } } - pub(crate) fn find(&mut self, x: usize) -> usize { - if self.parent[x] != x { - self.parent[x] = self.find(self.parent[x]); // Path compression + pub(crate) fn find(&mut self, x: NodeId) -> NodeId { + if self.parent[x as usize] != x { + self.parent[x as usize] = self.find(self.parent[x as usize]); // Path compression } - self.parent[x] + self.parent[x as usize] } - pub(crate) fn union(&mut self, x: usize, y: usize) -> bool { + pub(crate) fn union(&mut self, x: NodeId, y: NodeId) -> bool { let px = self.find(x); let py = self.find(y); if px == py { return false; } // Union by rank - match self.rank[px].cmp(&self.rank[py]) { - std::cmp::Ordering::Less => self.parent[px] = py, - std::cmp::Ordering::Greater => self.parent[py] = px, + match self.rank[px as usize].cmp(&self.rank[py as usize]) { + std::cmp::Ordering::Less => self.parent[px as usize] = py, + std::cmp::Ordering::Greater => self.parent[py as usize] = px, std::cmp::Ordering::Equal => { - self.parent[py] = px; - self.rank[px] += 1; + self.parent[py as usize] = px; + self.rank[px as usize] += 1; } } true From 7f8ee78c50de749a91af702dfb5d4c8add7f9933 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 08:00:46 +0100 Subject: [PATCH 16/29] Update lib.rs --- crates/famst/src/lib.rs | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 4196eea..f824368 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -182,7 +182,8 @@ where let ann_graph = nn_descent(data, &distance_fn, config, rng); // Phase 2: Build undirected graph and find connected components - let (undirected_graph, components) = find_components(&ann_graph); + let undirected_graph = build_undirected_graph(&ann_graph); + let components = find_components(&undirected_graph); // If only one component, skip inter-component edge logic if components.len() <= 1 { @@ -231,12 +232,11 @@ where } } -/// Find connected components in the ANN graph using DFS -/// Returns the undirected graph adjacency list (sorted vecs) and component assignments -fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) { +/// Build an undirected graph 
from the directed ANN graph +/// Returns adjacency list with sorted, deduplicated neighbors +fn build_undirected_graph(ann_graph: &AnnGraph) -> Vec> { let n = ann_graph.n(); - // Build undirected graph from directed ANN graph using sorted vecs let mut graph: Vec> = vec![Vec::new(); n]; for i in 0..n { for neighbor in ann_graph.neighbors(i) { @@ -251,7 +251,14 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) adj.dedup(); } - // DFS to find components + graph +} + +/// Find connected components in an undirected graph using DFS +/// Returns component assignments as a list of node lists +fn find_components(graph: &[Vec]) -> Vec> { + let n = graph.len(); + let mut visited = vec![false; n]; let mut components: Vec> = Vec::new(); @@ -280,7 +287,7 @@ fn find_components(ann_graph: &AnnGraph) -> (Vec>, Vec>) components.push(component); } - (graph, components) + components } /// Add random edges between components (Algorithm 3 in the paper) From ea848b53ab72cf4565b8de22f8f071d02f46b465 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 08:09:10 +0100 Subject: [PATCH 17/29] improve find_components a bit --- crates/famst/src/lib.rs | 40 ++-------------------------------- crates/famst/src/union_find.rs | 38 +++++++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index f824368..c6c711f 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -17,7 +17,7 @@ mod union_find; use nn_descent::nn_descent; use rand::seq::SliceRandom; use rand::Rng; -use union_find::UnionFind; +use union_find::{find_components, UnionFind}; /// Node index type (32-bit for memory efficiency, limits graphs to 2^32 nodes) pub type NodeId = u32; @@ -183,7 +183,7 @@ where // Phase 2: Build undirected graph and find connected components let undirected_graph = build_undirected_graph(&ann_graph); - let components = find_components(&undirected_graph); + let components = find_components(&ann_graph); // If only one component, skip inter-component edge logic if components.len() <= 1 { @@ -254,42 +254,6 @@ fn build_undirected_graph(ann_graph: &AnnGraph) -> Vec> { graph } -/// Find connected components in an undirected graph using DFS -/// Returns component assignments as a list of node lists -fn find_components(graph: &[Vec]) -> Vec> { - let n = graph.len(); - - let mut visited = vec![false; n]; - let mut components: Vec> = Vec::new(); - - for start in 0..n { - if visited[start] { - continue; - } - - let mut component = Vec::new(); - let mut stack = vec![start as NodeId]; - - while let Some(u) = stack.pop() { - if visited[u as usize] { - continue; - } - visited[u as usize] = true; - component.push(u); - - for &v in &graph[u as usize] { - if !visited[v as usize] { - stack.push(v); - } - } - } - - components.push(component); - } - - components -} - /// Add random edges between components (Algorithm 3 in the paper) fn add_random_edges( data: &[T], diff --git a/crates/famst/src/union_find.rs b/crates/famst/src/union_find.rs index a76358b..ee5c75d 100644 --- a/crates/famst/src/union_find.rs +++ b/crates/famst/src/union_find.rs @@ -1,6 +1,6 @@ //! 
Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm -use crate::NodeId; +use crate::{AnnGraph, NodeId}; /// Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm pub(crate) struct UnionFind { @@ -41,3 +41,39 @@ impl UnionFind { true } } + +/// Find connected components from the ANN graph +/// Treats directed edges as undirected for connectivity +/// Returns component assignments as a list of node lists +pub(crate) fn find_components(ann_graph: &AnnGraph) -> Vec> { + let n = ann_graph.n(); + let mut uf = UnionFind::new(n); + + // Union all edges (treating directed as undirected) + for i in 0..n { + for neighbor in ann_graph.neighbors(i) { + uf.union(i as NodeId, neighbor.index); + } + } + + // First pass: assign contiguous component IDs to each root + let mut root_to_component: Vec = vec![NodeId::MAX; n]; + let mut num_components = 0; + for i in 0..n { + let root = uf.find(i as NodeId) as usize; + if root_to_component[root] == NodeId::MAX { + root_to_component[root] = num_components; + num_components += 1; + } + } + + // Second pass: fill components directly + let mut components: Vec> = vec![Vec::new(); num_components as usize]; + for i in 0..n { + let root = uf.find(i as NodeId) as usize; + let component_id = root_to_component[root] as usize; + components[component_id].push(i as NodeId); + } + + components +} From 0014f86e0dd98a8e7b988ba464f6b7ed5312e024 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 08:54:13 +0100 Subject: [PATCH 18/29] more simplifications --- crates/famst/src/lib.rs | 21 +++- crates/famst/src/nn_descent.rs | 182 +++++++++++++++++---------------- 2 files changed, 108 insertions(+), 95 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index c6c711f..846bbe0 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -45,7 +45,7 @@ pub(crate) struct Neighbor { impl PartialEq for Neighbor { fn eq(&self, other: &Self) -> bool { - self.distance == other.distance + self.distance == other.distance && self.index == other.index } } @@ -59,10 +59,11 @@ impl PartialOrd for Neighbor { impl Ord for Neighbor { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - // Max-heap: larger distances have higher priority + // Total ordering by (distance, index) self.distance .partial_cmp(&other.distance) .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| self.index.cmp(&other.index)) } } @@ -78,20 +79,30 @@ pub(crate) struct AnnGraph { } impl AnnGraph { - fn new(n: usize, k: usize, data: Vec) -> Self { + pub(crate) fn new(n: usize, k: usize, data: Vec) -> Self { assert_eq!(data.len(), n * k); AnnGraph { data, n, k } } - fn n(&self) -> usize { + pub(crate) fn n(&self) -> usize { self.n } + pub(crate) fn k(&self) -> usize { + self.k + } + /// Get the neighbors of point i - fn neighbors(&self, i: usize) -> &[Neighbor] { + pub(crate) fn neighbors(&self, i: usize) -> &[Neighbor] { let start = i * self.k; &self.data[start..start + self.k] } + + /// Get mutable access to neighbors of point i + pub(crate) fn neighbors_mut(&mut self, i: usize) -> &mut [Neighbor] { + let start = i * self.k; + &mut self.data[start..start + self.k] + } } /// FAMST algorithm configuration diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index e72f827..bed831c 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -5,44 +5,94 @@ use rand::seq::SliceRandom; use rand::Rng; -use std::collections::BinaryHeap; +use std::collections::HashSet; use crate::{AnnGraph, 
FamstConfig, Neighbor, NodeId}; -/// Check if sorted vec contains value -fn sorted_contains(v: &[NodeId], x: NodeId) -> bool { - v.binary_search(&x).is_ok() -} - -/// Insert into sorted vec, returns true if inserted (was not present) -fn sorted_insert(v: &mut Vec, x: NodeId) -> bool { - match v.binary_search(&x) { - Ok(_) => false, - Err(pos) => { - v.insert(pos, x); - true +/// Build reverse neighbor lists (who has me as a neighbor) +fn build_reverse(graph: &AnnGraph) -> Vec> { + let n = graph.n(); + let mut reverse: Vec> = vec![Vec::new(); n]; + for i in 0..n { + for neighbor in graph.neighbors(i) { + reverse[neighbor.index as usize].push(i as NodeId); } } + reverse } -/// Remove from sorted vec -fn sorted_remove(v: &mut Vec, x: NodeId) { - if let Ok(pos) = v.binary_search(&x) { - v.remove(pos); +/// Initialize ANN graph with random neighbors +fn init_random_graph(data: &[T], k: usize, distance_fn: &D, rng: &mut R) -> AnnGraph +where + D: Fn(&T, &T) -> f32, + R: Rng, +{ + let n = data.len(); + let mut graph_data: Vec = Vec::with_capacity(n * k); + + for i in 0..n { + let mut neighbors: Vec = Vec::with_capacity(k); + let mut seen: HashSet = HashSet::with_capacity(k); + + // Sample k random neighbors using Floyd's algorithm - guaranteed O(k) + let effective_n = n - 1; // exclude self + let range_start = effective_n.saturating_sub(k); + for t in range_start..effective_n { + let j = rng.gen_range(0..=t); + // Map j to actual index, skipping i + let actual_j = (if j >= i { j + 1 } else { j }) as NodeId; + + let selected = if seen.insert(actual_j) { + actual_j + } else { + // j was already selected, so add t instead + let actual_t = (if t >= i { t + 1 } else { t }) as NodeId; + seen.insert(actual_t); + actual_t + }; + + let d = distance_fn(&data[i], &data[selected as usize]); + neighbors.push(Neighbor { + index: selected, + distance: d, + }); + } + + // Sort by (distance, index) for total ordering + neighbors.sort(); + graph_data.extend(neighbors); } + + AnnGraph::new(n, k, graph_data) } -/// Build reverse neighbor lists (who has me as a neighbor) -/// Returns sorted vecs for each point -fn build_reverse(neighbor_lists: &[Vec], n: usize) -> Vec> { - let mut reverse: Vec> = vec![Vec::new(); n]; - for (i, neighbors) in neighbor_lists.iter().enumerate() { - for &j in neighbors { - reverse[j as usize].push(i as NodeId); +/// Try to insert a new neighbor into a sorted neighbor slice. +/// Returns true if the neighbor was inserted (better than the worst). +/// Assumes neighbors are sorted by (distance, index) for total ordering. 
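+///
+/// For example, with k = 3 and a slice holding distances `[1.0, 3.0, 5.0]`, a
+/// candidate at distance 2.0 is inserted, giving `[1.0, 2.0, 3.0]` (the worst
+/// entry, 5.0, falls off the end), while a candidate at distance 6.0 is rejected
+/// and the slice is left unchanged.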
+fn insert_neighbor(neighbors: &mut [Neighbor], new_index: NodeId, new_distance: f32) -> bool { + let new_neighbor = Neighbor { + index: new_index, + distance: new_distance, + }; + + // Binary search using total ordering - also serves as existence check + match neighbors.binary_search(&new_neighbor) { + Ok(_) => false, // Already exists + Err(insert_pos) => { + // Check if better than worst (last element) + if insert_pos >= neighbors.len() { + return false; + } + + // Shift elements to make room (dropping the last/worst) + for j in (insert_pos + 1..neighbors.len()).rev() { + neighbors[j] = neighbors[j - 1]; + } + + neighbors[insert_pos] = new_neighbor; + true } } - // Each reverse list is built in order of i, so already sorted - reverse } /// NN-Descent algorithm for approximate k-NN graph construction @@ -66,56 +116,26 @@ where return AnnGraph::new(n, 0, vec![]); } - // Initialize with random neighbors using max-heap for each point - // neighbor_lists[i] is kept sorted by index for O(log k) membership tests - let mut heaps: Vec> = Vec::with_capacity(n); - let mut neighbor_lists: Vec> = vec![Vec::with_capacity(k); n]; - - for i in 0..n { - let mut heap = BinaryHeap::with_capacity(k); - - // Sample k random neighbors using Floyd's algorithm - guaranteed O(k) - // https://fermatslibrary.com/s/a-sample-of-brilliance - // This selects k distinct elements from 0..n, excluding i - let effective_n = n - 1; // exclude self - let range_start = effective_n.saturating_sub(k); - for t in range_start..effective_n { - let j = rng.gen_range(0..=t); - // Map j to actual index, skipping i - let actual_j = (if j >= i { j + 1 } else { j }) as NodeId; - - if !sorted_insert(&mut neighbor_lists[i], actual_j) { - // j was already selected, so add t instead - let actual_t = (if t >= i { t + 1 } else { t }) as NodeId; - sorted_insert(&mut neighbor_lists[i], actual_t); - let d = distance_fn(&data[i], &data[actual_t as usize]); - heap.push(Neighbor { - index: actual_t, - distance: d, - }); - } else { - let d = distance_fn(&data[i], &data[actual_j as usize]); - heap.push(Neighbor { - index: actual_j, - distance: d, - }); - } - } - heaps.push(heap); - } + // Initialize ANN graph with random neighbors + let mut graph = init_random_graph(data, k, distance_fn, rng); // NN-Descent iterations for _ in 0..config.nn_descent_iterations { let mut updates = 0; - let reverse_neighbors = build_reverse(&neighbor_lists, n); + let reverse_neighbors = build_reverse(&graph); // For each point, explore neighbors of neighbors for i in 0..n { + // Build set of current neighbors for O(1) lookup + let current_neighbors: HashSet = + graph.neighbors(i).iter().map(|nb| nb.index).collect(); + // Collect candidates: neighbors and reverse neighbors let mut candidates: Vec = Vec::new(); // Sample from forward neighbors - let mut sampled_forward = neighbor_lists[i].clone(); + let mut sampled_forward: Vec = + graph.neighbors(i).iter().map(|nb| nb.index).collect(); let sample_size = ((sampled_forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize) .max(1); @@ -133,14 +153,14 @@ where // Neighbors of neighbors let i_id = i as NodeId; for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) { - for &nn in &neighbor_lists[neighbor as usize] { - if nn != i_id && !sorted_contains(&neighbor_lists[i], nn) { - candidates.push(nn); + for nb in graph.neighbors(neighbor as usize) { + if nb.index != i_id && !current_neighbors.contains(&nb.index) { + candidates.push(nb.index); } } // Also check reverse neighbors of neighbors for &rn 
in &reverse_neighbors[neighbor as usize] { - if rn != i_id && !sorted_contains(&neighbor_lists[i], rn) { + if rn != i_id && !current_neighbors.contains(&rn) { candidates.push(rn); } } @@ -154,17 +174,8 @@ where for c in candidates { let d = distance_fn(&data[i], &data[c as usize]); - // Check if this is better than the worst current neighbor - if let Some(worst) = heaps[i].peek() { - if d < worst.distance { - // Remove worst and add new neighbor - let removed = heaps[i].pop().unwrap(); - sorted_remove(&mut neighbor_lists[i], removed.index); - - heaps[i].push(Neighbor { index: c, distance: d }); - sorted_insert(&mut neighbor_lists[i], c); - updates += 1; - } + if insert_neighbor(graph.neighbors_mut(i), c, d) { + updates += 1; } } } @@ -175,14 +186,5 @@ where } } - // Convert heaps to flat neighbor array sorted by distance - let mut result_data = Vec::with_capacity(n * k); - - for heap in heaps { - let mut entries: Vec = heap.into_vec(); - entries.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); - result_data.extend(entries); - } - - AnnGraph::new(n, k, result_data) + graph } From dfad6d7a5876331dd6548eb564a608dff6fa34c1 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 09:11:02 +0100 Subject: [PATCH 19/29] fix tests --- crates/famst/src/lib.rs | 13 +++++++------ crates/famst/src/nn_descent.rs | 4 +++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 846bbe0..b495e3d 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -195,6 +195,7 @@ where // Phase 2: Build undirected graph and find connected components let undirected_graph = build_undirected_graph(&ann_graph); let components = find_components(&ann_graph); + println!("Found {} connected components", components.len()); // If only one component, skip inter-component edge logic if components.len() <= 1 { @@ -661,7 +662,7 @@ mod tests { // Check all nodes are in the same component let root = uf.find(0); for i in 1..n { - assert_eq!(uf.find(i), root, "All nodes should be connected in the MST"); + assert_eq!(uf.find(i as NodeId), root, "All nodes should be connected in the MST"); } // Compare with exact MST @@ -711,7 +712,7 @@ mod tests { let mut edge_count = 0; for (u, v, w) in edges { - if uf.union(u, v) { + if uf.union(u as NodeId, v as NodeId) { total_weight += w; edge_count += 1; if edge_count == n - 1 { @@ -764,8 +765,8 @@ mod tests { fn test_medium_scale_vs_exact() { use rand::distributions::{Distribution, Uniform}; - const N: usize = 5000; - const DIM: usize = 5; + const N: usize = 10000; + const DIM: usize = 2; let mut rng = StdRng::seed_from_u64(99999); let dist = Uniform::new(0.0, 100.0); @@ -781,10 +782,10 @@ mod tests { // Compute approximate MST with FAMST let config = FamstConfig { - k: 15, + k: 4, lambda: 5, max_iterations: 100, - nn_descent_iterations: 15, + nn_descent_iterations: 100, nn_descent_sample_rate: 0.5, }; let mut famst_rng = StdRng::seed_from_u64(11111); diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index bed831c..21d7944 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -117,10 +117,12 @@ where } // Initialize ANN graph with random neighbors + println!("Initializing random graph"); let mut graph = init_random_graph(data, k, distance_fn, rng); // NN-Descent iterations - for _ in 0..config.nn_descent_iterations { + for iter in 0..config.nn_descent_iterations { + println!("NN-Descent iteration {iter}"); let mut updates = 0; let 
reverse_neighbors = build_reverse(&graph); From 2375b361821cd76c8dfa992a89c38d88c685d453 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 09:18:53 +0100 Subject: [PATCH 20/29] start using rayon --- crates/famst/Cargo.toml | 3 +- crates/famst/src/lib.rs | 6 ++- crates/famst/src/nn_descent.rs | 80 +++++++++++++++++++--------------- 3 files changed, 51 insertions(+), 38 deletions(-) diff --git a/crates/famst/Cargo.toml b/crates/famst/Cargo.toml index da26573..62bacd0 100644 --- a/crates/famst/Cargo.toml +++ b/crates/famst/Cargo.toml @@ -6,7 +6,8 @@ description = "Fast Approximate Minimum Spanning Tree (FAMST) algorithm" license = "MIT" [dependencies] -rand = "0.8" +rand = { version = "0.8", features = ["small_rng"] } +rayon = "1" [dev-dependencies] rand = "0.8" diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index b495e3d..4e76bfe 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -157,7 +157,8 @@ pub struct FamstResult { /// Panics if `data.len() >= 2^32` (more than ~4 billion points). pub fn famst(data: &[T], distance_fn: D, config: &FamstConfig) -> FamstResult where - D: Fn(&T, &T) -> f32, + T: Sync, + D: Fn(&T, &T) -> f32 + Sync, { famst_with_rng(data, distance_fn, config, &mut rand::thread_rng()) } @@ -173,7 +174,8 @@ pub fn famst_with_rng( rng: &mut R, ) -> FamstResult where - D: Fn(&T, &T) -> f32, + T: Sync, + D: Fn(&T, &T) -> f32 + Sync, R: Rng, { let n = data.len(); diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index 21d7944..b485c57 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -4,7 +4,9 @@ //! by Wei Dong, Charikar Moses, and Kai Li (2011) use rand::seq::SliceRandom; -use rand::Rng; +use rand::{Rng, SeedableRng}; +use rand::rngs::SmallRng; +use rayon::prelude::*; use std::collections::HashSet; use crate::{AnnGraph, FamstConfig, Neighbor, NodeId}; @@ -24,44 +26,51 @@ fn build_reverse(graph: &AnnGraph) -> Vec> { /// Initialize ANN graph with random neighbors fn init_random_graph(data: &[T], k: usize, distance_fn: &D, rng: &mut R) -> AnnGraph where - D: Fn(&T, &T) -> f32, + T: Sync, + D: Fn(&T, &T) -> f32 + Sync, R: Rng, { let n = data.len(); - let mut graph_data: Vec = Vec::with_capacity(n * k); - for i in 0..n { - let mut neighbors: Vec = Vec::with_capacity(k); - let mut seen: HashSet = HashSet::with_capacity(k); - - // Sample k random neighbors using Floyd's algorithm - guaranteed O(k) - let effective_n = n - 1; // exclude self - let range_start = effective_n.saturating_sub(k); - for t in range_start..effective_n { - let j = rng.gen_range(0..=t); - // Map j to actual index, skipping i - let actual_j = (if j >= i { j + 1 } else { j }) as NodeId; - - let selected = if seen.insert(actual_j) { - actual_j - } else { - // j was already selected, so add t instead - let actual_t = (if t >= i { t + 1 } else { t }) as NodeId; - seen.insert(actual_t); - actual_t - }; - - let d = distance_fn(&data[i], &data[selected as usize]); - neighbors.push(Neighbor { - index: selected, - distance: d, - }); - } + // Generate seeds for per-thread RNGs + let seeds: Vec = (0..n).map(|_| rng.r#gen()).collect(); + + let graph_data: Vec = (0..n) + .into_par_iter() + .flat_map(|i| { + let mut local_rng = SmallRng::seed_from_u64(seeds[i]); + let mut neighbors: Vec = Vec::with_capacity(k); + let mut seen: HashSet = HashSet::with_capacity(k); + + // Sample k random neighbors using Floyd's algorithm - guaranteed O(k) + let effective_n = n - 1; // exclude self + let range_start = 
effective_n.saturating_sub(k); + for t in range_start..effective_n { + let j = local_rng.gen_range(0..=t); + // Map j to actual index, skipping i + let actual_j = (if j >= i { j + 1 } else { j }) as NodeId; + + let selected = if seen.insert(actual_j) { + actual_j + } else { + // j was already selected, so add t instead + let actual_t = (if t >= i { t + 1 } else { t }) as NodeId; + seen.insert(actual_t); + actual_t + }; + + let d = distance_fn(&data[i], &data[selected as usize]); + neighbors.push(Neighbor { + index: selected, + distance: d, + }); + } - // Sort by (distance, index) for total ordering - neighbors.sort(); - graph_data.extend(neighbors); - } + // Sort by (distance, index) for total ordering + neighbors.sort(); + neighbors + }) + .collect(); AnnGraph::new(n, k, graph_data) } @@ -106,7 +115,8 @@ pub(crate) fn nn_descent( rng: &mut R, ) -> AnnGraph where - D: Fn(&T, &T) -> f32, + T: Sync, + D: Fn(&T, &T) -> f32 + Sync, R: Rng, { let n = data.len(); From 9b45210a1bc4ac5e86efb10e46859c4c9029445c Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 09:48:44 +0100 Subject: [PATCH 21/29] introduce new marker --- crates/famst/src/lib.rs | 110 +++++++++++++++---- crates/famst/src/nn_descent.rs | 191 +++++++++++++++++++++++---------- crates/famst/src/union_find.rs | 24 ++--- 3 files changed, 236 insertions(+), 89 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 4e76bfe..10d8a50 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -19,19 +19,84 @@ use rand::seq::SliceRandom; use rand::Rng; use union_find::{find_components, UnionFind}; -/// Node index type (32-bit for memory efficiency, limits graphs to 2^32 nodes) -pub type NodeId = u32; +/// Node index with embedded "new" flag in the least significant bit. +/// +/// The LSB (bit 0) is used as a "new" flag for NN-Descent optimization. +/// The actual node index is stored in bits 1-31, giving capacity for 2^31 nodes. +/// +/// Using LSB means that during sorting, the same node ID with or without the +/// new flag will be adjacent (differ only by 1), which helps with existence checks. 
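+///
+/// For example, `NodeId::new(5)` stores the raw value `10` (i.e. `5 << 1`) and
+/// `NodeId::new(5).as_new()` stores `11`; both report `index() == 5` and compare
+/// as equal, because equality and ordering ignore the flag bit.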
+#[derive(Debug, Clone, Copy, Default, Hash)] +pub struct NodeId(u32); + +impl NodeId { + /// Flag bit indicating a "new" neighbor (LSB) + const NEW_FLAG: u32 = 1; + + /// Create a new NodeId from a raw index (not marked as new) + #[inline] + pub fn new(index: u32) -> Self { + debug_assert!(index < (1 << 31), "NodeId index must be < 2^31"); + NodeId(index << 1) + } + + /// Get the actual node index (without the new flag) + #[inline] + pub fn index(self) -> u32 { + self.0 >> 1 + } + + /// Check if this NodeId has the "new" flag set + #[inline] + pub fn is_new(self) -> bool { + self.0 & Self::NEW_FLAG != 0 + } + + /// Return a copy with the "new" flag set + #[inline] + pub fn as_new(self) -> Self { + NodeId(self.0 | Self::NEW_FLAG) + } + + /// Return a copy with the "new" flag cleared + #[inline] + pub fn as_old(self) -> Self { + NodeId(self.0 & !Self::NEW_FLAG) + } +} + +impl PartialEq for NodeId { + fn eq(&self, other: &Self) -> bool { + // Compare by actual index, ignoring new flag + self.index() == other.index() + } +} + +impl Eq for NodeId {} + +impl PartialOrd for NodeId { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for NodeId { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Compare by actual index, ignoring new flag + self.index().cmp(&other.index()) + } +} /// An edge in the MST, represented as (node_a, node_b, distance) #[derive(Debug, Clone)] pub struct Edge { - pub u: NodeId, - pub v: NodeId, + pub u: u32, + pub v: u32, pub distance: f32, } impl Edge { - fn new(u: NodeId, v: NodeId, distance: f32) -> Self { + fn new(u: u32, v: u32, distance: f32) -> Self { Edge { u, v, distance } } } @@ -45,6 +110,7 @@ pub(crate) struct Neighbor { impl PartialEq for Neighbor { fn eq(&self, other: &Self) -> bool { + // Compare by (distance, index) - index comparison ignores new flag self.distance == other.distance && self.index == other.index } } @@ -59,7 +125,7 @@ impl PartialOrd for Neighbor { impl Ord for Neighbor { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - // Total ordering by (distance, index) + // Total ordering by (distance, index) - index comparison ignores new flag self.distance .partial_cmp(&other.distance) .unwrap_or(std::cmp::Ordering::Equal) @@ -180,8 +246,8 @@ where { let n = data.len(); assert!( - n <= NodeId::MAX as usize, - "famst: data length {n} exceeds maximum supported size of 2^32" + n <= (1 << 31), + "famst: data length {n} exceeds maximum supported size of 2^31" ); if n <= 1 { @@ -256,7 +322,7 @@ fn build_undirected_graph(ann_graph: &AnnGraph) -> Vec> { for neighbor in ann_graph.neighbors(i) { let j = neighbor.index; graph[i].push(j); - graph[j as usize].push(i as NodeId); + graph[j.index() as usize].push(NodeId::new(i as u32)); } } // Sort and deduplicate each adjacency list @@ -294,8 +360,8 @@ where for _ in 0..lambda_sq { let u = *components[i].choose(rng).unwrap(); let v = *components[j].choose(rng).unwrap(); - let d = distance_fn(&data[u as usize], &data[v as usize]); - candidates.push(Edge::new(u, v, d)); + let d = distance_fn(&data[u.index() as usize], &data[v.index() as usize]); + candidates.push(Edge::new(u.index(), v.index(), d)); } // Sort by distance and take top λ @@ -329,7 +395,7 @@ where let mut node_to_component: Vec = vec![0; n]; for (comp_idx, component) in components.iter().enumerate() { for &node in component { - node_to_component[node as usize] = comp_idx; + node_to_component[node.index() as usize] = comp_idx; } } @@ -344,11 +410,12 @@ where // Get neighbors of u that are in 
component ci // (undirected_graph is sorted, so we can iterate directly) for &u_prime in &undirected_graph[edge.u as usize] { - let u_prime_usize = u_prime as usize; - if node_to_component[u_prime_usize] == ci && u_prime != edge.v { + let u_prime_idx = u_prime.index(); + let u_prime_usize = u_prime_idx as usize; + if node_to_component[u_prime_usize] == ci && u_prime_idx != edge.v { let d_prime = distance_fn(&data[u_prime_usize], &data[best_v as usize]); if d_prime < best_d { - best_u = u_prime; + best_u = u_prime_idx; best_d = d_prime; } } @@ -356,11 +423,12 @@ where // Get neighbors of v that are in component cj for &v_prime in &undirected_graph[edge.v as usize] { - let v_prime_usize = v_prime as usize; - if node_to_component[v_prime_usize] == cj && v_prime != edge.u { + let v_prime_idx = v_prime.index(); + let v_prime_usize = v_prime_idx as usize; + if node_to_component[v_prime_usize] == cj && v_prime_idx != edge.u { let d_prime = distance_fn(&data[best_u as usize], &data[v_prime_usize]); if d_prime < best_d { - best_v = v_prime; + best_v = v_prime_idx; best_d = d_prime; } } @@ -383,7 +451,7 @@ fn extract_mst(ann_graph: &AnnGraph, inter_edges: &[Edge], n: usize) -> Vec Vec> { +/// Build reverse neighbor lists, separating new and old neighbors +/// Returns (old_reverse, new_reverse) +fn build_reverse_lists(graph: &AnnGraph) -> (Vec>, Vec>) { let n = graph.n(); - let mut reverse: Vec> = vec![Vec::new(); n]; + let mut old_reverse: Vec> = vec![Vec::new(); n]; + let mut new_reverse: Vec> = vec![Vec::new(); n]; + for i in 0..n { + let i_id = NodeId::new(i as u32); for neighbor in graph.neighbors(i) { - reverse[neighbor.index as usize].push(i as NodeId); + let target = neighbor.index.index() as usize; + if neighbor.index.is_new() { + new_reverse[target].push(i_id); + } else { + old_reverse[target].push(i_id); + } } } - reverse + (old_reverse, new_reverse) } /// Initialize ANN graph with random neighbors @@ -40,7 +52,7 @@ where .flat_map(|i| { let mut local_rng = SmallRng::seed_from_u64(seeds[i]); let mut neighbors: Vec = Vec::with_capacity(k); - let mut seen: HashSet = HashSet::with_capacity(k); + let mut seen: HashSet = HashSet::with_capacity(k); // Sample k random neighbors using Floyd's algorithm - guaranteed O(k) let effective_n = n - 1; // exclude self @@ -48,20 +60,21 @@ where for t in range_start..effective_n { let j = local_rng.gen_range(0..=t); // Map j to actual index, skipping i - let actual_j = (if j >= i { j + 1 } else { j }) as NodeId; + let actual_j = (if j >= i { j + 1 } else { j }) as u32; let selected = if seen.insert(actual_j) { actual_j } else { // j was already selected, so add t instead - let actual_t = (if t >= i { t + 1 } else { t }) as NodeId; + let actual_t = (if t >= i { t + 1 } else { t }) as u32; seen.insert(actual_t); actual_t }; let d = distance_fn(&data[i], &data[selected as usize]); + // Mark as new (not yet used for candidate generation) neighbors.push(Neighbor { - index: selected, + index: NodeId::new(selected).as_new(), distance: d, }); } @@ -78,13 +91,16 @@ where /// Try to insert a new neighbor into a sorted neighbor slice. /// Returns true if the neighbor was inserted (better than the worst). /// Assumes neighbors are sorted by (distance, index) for total ordering. +/// The new flag (LSB) doesn't affect ordering since NodeId::cmp ignores it. 
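+///
+/// For illustration, with k = 3 and entries written as (index, distance):
+///
+///     [(4, 0.1), (9, 0.5), (2, 0.9)]  insert(7, 0.3) -> [(4, 0.1), (7, 0.3), (9, 0.5)]  returns true, worst entry dropped
+///     [(4, 0.1), (9, 0.5), (2, 0.9)]  insert(7, 1.2) -> unchanged                        returns false, worse than the worst
+///     [(4, 0.1), (9, 0.5), (2, 0.9)]  insert(9, 0.5) -> unchanged                        returns false, already present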
fn insert_neighbor(neighbors: &mut [Neighbor], new_index: NodeId, new_distance: f32) -> bool { + // Create a search key with the new flag set (new insertions are always "new") let new_neighbor = Neighbor { - index: new_index, + index: new_index.as_new(), distance: new_distance, }; - // Binary search using total ordering - also serves as existence check + // Binary search by (distance, index) - NodeId comparison ignores new flag + // so this will find the node regardless of its new/old status match neighbors.binary_search(&new_neighbor) { Ok(_) => false, // Already exists Err(insert_pos) => { @@ -108,6 +124,8 @@ fn insert_neighbor(neighbors: &mut [Neighbor], new_index: NodeId, new_distance: /// /// Based on: "Efficient K-Nearest Neighbor Graph Construction for Generic Similarity Measures" /// by Wei Dong, Charikar Moses, and Kai Li (2011) +/// +/// Uses the "new/old" optimization: only compare pairs where at least one neighbor is "new". pub(crate) fn nn_descent( data: &[T], distance_fn: &D, @@ -126,77 +144,138 @@ where return AnnGraph::new(n, 0, vec![]); } - // Initialize ANN graph with random neighbors + // Initialize ANN graph with random neighbors (all marked as "new") println!("Initializing random graph"); let mut graph = init_random_graph(data, k, distance_fn, rng); // NN-Descent iterations for iter in 0..config.nn_descent_iterations { - println!("NN-Descent iteration {iter}"); + // Build reverse neighbor lists, separating old and new + let (old_reverse, new_reverse) = build_reverse_lists(&graph); + + // For each point, collect old and new forward neighbors + let mut old_neighbors: Vec> = vec![Vec::new(); n]; + let mut new_neighbors: Vec> = vec![Vec::new(); n]; + + for i in 0..n { + for nb in graph.neighbors(i) { + let idx = nb.index.as_old(); // Strip new flag for storage + if nb.index.is_new() { + new_neighbors[i].push(idx); + } else { + old_neighbors[i].push(idx); + } + } + } + + // Mark all neighbors as old for next iteration + for i in 0..n { + for nb in graph.neighbors_mut(i) { + nb.index = nb.index.as_old(); + } + } + let mut updates = 0; - let reverse_neighbors = build_reverse(&graph); - // For each point, explore neighbors of neighbors + // For each point, generate candidates from neighbors of neighbors + // Key optimization: only consider pairs where at least one is "new" for i in 0..n { + let i_id = NodeId::new(i as u32); + + // Combine forward and reverse neighbors + let old_i: Vec = old_neighbors[i] + .iter() + .chain(old_reverse[i].iter()) + .copied() + .collect(); + let new_i: Vec = new_neighbors[i] + .iter() + .chain(new_reverse[i].iter()) + .copied() + .collect(); + + // Skip if no new neighbors + if new_i.is_empty() { + continue; + } + // Build set of current neighbors for O(1) lookup - let current_neighbors: HashSet = - graph.neighbors(i).iter().map(|nb| nb.index).collect(); - - // Collect candidates: neighbors and reverse neighbors - let mut candidates: Vec = Vec::new(); - - // Sample from forward neighbors - let mut sampled_forward: Vec = - graph.neighbors(i).iter().map(|nb| nb.index).collect(); - let sample_size = - ((sampled_forward.len() as f64 * config.nn_descent_sample_rate).ceil() as usize) - .max(1); - sampled_forward.shuffle(rng); - sampled_forward.truncate(sample_size); - - // Sample from reverse neighbors - let mut sampled_reverse = reverse_neighbors[i].clone(); - let sample_size = - ((sampled_reverse.len() as f64 * config.nn_descent_sample_rate).ceil() as usize) - .max(1); - sampled_reverse.shuffle(rng); - sampled_reverse.truncate(sample_size); - - // 
Neighbors of neighbors - let i_id = i as NodeId; - for &neighbor in sampled_forward.iter().chain(sampled_reverse.iter()) { - for nb in graph.neighbors(neighbor as usize) { - if nb.index != i_id && !current_neighbors.contains(&nb.index) { - candidates.push(nb.index); + let current_neighbors: HashSet = graph + .neighbors(i) + .iter() + .map(|nb| nb.index.as_old()) + .collect(); + + let mut candidates: HashSet = HashSet::new(); + + // new-new pairs: for each new neighbor, look at their new neighbors + for &u in &new_i { + let u_idx = u.index() as usize; + for &v in &new_neighbors[u_idx] { + if v != i_id && !current_neighbors.contains(&v) { + candidates.insert(v); } } - // Also check reverse neighbors of neighbors - for &rn in &reverse_neighbors[neighbor as usize] { - if rn != i_id && !current_neighbors.contains(&rn) { - candidates.push(rn); + for &v in &new_reverse[u_idx] { + if v != i_id && !current_neighbors.contains(&v) { + candidates.insert(v); } } } - // Deduplicate candidates - candidates.sort_unstable(); - candidates.dedup(); + // new-old pairs: for each new neighbor, look at their old neighbors + for &u in &new_i { + let u_idx = u.index() as usize; + for &v in &old_neighbors[u_idx] { + if v != i_id && !current_neighbors.contains(&v) { + candidates.insert(v); + } + } + for &v in &old_reverse[u_idx] { + if v != i_id && !current_neighbors.contains(&v) { + candidates.insert(v); + } + } + } - // Try to improve neighbors - for c in candidates { - let d = distance_fn(&data[i], &data[c as usize]); + // old-new pairs: for each old neighbor, look at their new neighbors + for &u in &old_i { + let u_idx = u.index() as usize; + for &v in &new_neighbors[u_idx] { + if v != i_id && !current_neighbors.contains(&v) { + candidates.insert(v); + } + } + for &v in &new_reverse[u_idx] { + if v != i_id && !current_neighbors.contains(&v) { + candidates.insert(v); + } + } + } + // Try to improve neighbors with candidates + for c in candidates { + let d = distance_fn(&data[i], &data[c.index() as usize]); if insert_neighbor(graph.neighbors_mut(i), c, d) { updates += 1; } } } + println!("NN-Descent iteration {iter}: {updates} updates"); + // Early termination if no updates if updates == 0 { break; } } + // Strip the new flag from all neighbors before returning + for i in 0..n { + for nb in graph.neighbors_mut(i) { + nb.index = nb.index.as_old(); + } + } + graph } diff --git a/crates/famst/src/union_find.rs b/crates/famst/src/union_find.rs index ee5c75d..8e0a3ed 100644 --- a/crates/famst/src/union_find.rs +++ b/crates/famst/src/union_find.rs @@ -4,26 +4,26 @@ use crate::{AnnGraph, NodeId}; /// Union-Find (Disjoint Set Union) data structure for Kruskal's algorithm pub(crate) struct UnionFind { - parent: Vec, - rank: Vec, + parent: Vec, + rank: Vec, } impl UnionFind { pub(crate) fn new(n: usize) -> Self { UnionFind { - parent: (0..n as NodeId).collect(), + parent: (0..n as u32).collect(), rank: vec![0; n], } } - pub(crate) fn find(&mut self, x: NodeId) -> NodeId { + pub(crate) fn find(&mut self, x: u32) -> u32 { if self.parent[x as usize] != x { self.parent[x as usize] = self.find(self.parent[x as usize]); // Path compression } self.parent[x as usize] } - pub(crate) fn union(&mut self, x: NodeId, y: NodeId) -> bool { + pub(crate) fn union(&mut self, x: u32, y: u32) -> bool { let px = self.find(x); let py = self.find(y); if px == py { @@ -52,16 +52,16 @@ pub(crate) fn find_components(ann_graph: &AnnGraph) -> Vec> { // Union all edges (treating directed as undirected) for i in 0..n { for neighbor in 
ann_graph.neighbors(i) { - uf.union(i as NodeId, neighbor.index); + uf.union(i as u32, neighbor.index.index()); } } // First pass: assign contiguous component IDs to each root - let mut root_to_component: Vec = vec![NodeId::MAX; n]; - let mut num_components = 0; + let mut root_to_component: Vec = vec![u32::MAX; n]; + let mut num_components: u32 = 0; for i in 0..n { - let root = uf.find(i as NodeId) as usize; - if root_to_component[root] == NodeId::MAX { + let root = uf.find(i as u32) as usize; + if root_to_component[root] == u32::MAX { root_to_component[root] = num_components; num_components += 1; } @@ -70,9 +70,9 @@ pub(crate) fn find_components(ann_graph: &AnnGraph) -> Vec> { // Second pass: fill components directly let mut components: Vec> = vec![Vec::new(); num_components as usize]; for i in 0..n { - let root = uf.find(i as NodeId) as usize; + let root = uf.find(i as u32) as usize; let component_id = root_to_component[root] as usize; - components[component_id].push(i as NodeId); + components[component_id].push(NodeId::new(i as u32)); } components From a8af2c65c83cf4402dff3d142a1924bde28d596f Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 09:50:30 +0100 Subject: [PATCH 22/29] Update lib.rs --- crates/famst/src/lib.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 10d8a50..82ae640 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -181,8 +181,6 @@ pub struct FamstConfig { pub max_iterations: usize, /// Maximum NN-Descent iterations pub nn_descent_iterations: usize, - /// Sample rate for NN-Descent (fraction of neighbors to sample) - pub nn_descent_sample_rate: f64, } impl Default for FamstConfig { @@ -192,7 +190,6 @@ impl Default for FamstConfig { lambda: 5, max_iterations: 100, nn_descent_iterations: 10, - nn_descent_sample_rate: 0.5, } } } @@ -715,7 +712,6 @@ mod tests { lambda: 5, max_iterations: 50, nn_descent_iterations: 20, - nn_descent_sample_rate: 1.0, // Full sampling for small dataset }; let mut famst_rng = StdRng::seed_from_u64(88888); @@ -817,7 +813,6 @@ mod tests { lambda: 5, max_iterations: 100, nn_descent_iterations: 10, - nn_descent_sample_rate: 0.5, }; let mut famst_rng = StdRng::seed_from_u64(54321); let start = std::time::Instant::now(); @@ -856,7 +851,6 @@ mod tests { lambda: 5, max_iterations: 100, nn_descent_iterations: 100, - nn_descent_sample_rate: 0.5, }; let mut famst_rng = StdRng::seed_from_u64(11111); let result = famst_with_rng(&points, distance, &config, &mut famst_rng); From d4e68c73a8aa6ceb0d2340374ea9fa864fec9abb Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 11:39:55 +0100 Subject: [PATCH 23/29] rewrite reverse edge computation --- crates/famst/src/nn_descent.rs | 91 ++++++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 20 deletions(-) diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index f552144..9a8f3dd 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -14,21 +14,72 @@ use std::collections::HashSet; use crate::{AnnGraph, FamstConfig, Neighbor, NodeId}; -/// Build reverse neighbor lists, separating new and old neighbors -/// Returns (old_reverse, new_reverse) -fn build_reverse_lists(graph: &AnnGraph) -> (Vec>, Vec>) { +/// Reverse neighbor lists with bounded size k per node. +/// Uses flat storage: data[i*k..(i+1)*k] contains up to k reverse neighbors of node i. +/// Uses reservoir sampling when more than k reverse edges exist. 
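+///
+/// This is standard reservoir sampling (Algorithm R): after m >= k reverse edges
+/// have been offered to a node, each of them is kept with probability k / m.
+/// For example, with k = 20 and m = 100 offered edges, any particular reverse
+/// edge survives with probability 0.2, so the stored subset stays uniform
+/// without ever materializing all reverse edges.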
+struct ReverseNeighbors { + /// Flat storage of reverse neighbor IDs (with new flag preserved) + data: Vec, + /// Count of reverse neighbors seen so far (for reservoir sampling) + counts: Vec, + /// Max reverse neighbors per node + k: usize, +} + +impl ReverseNeighbors { + fn new(n: usize, k: usize) -> Self { + ReverseNeighbors { + data: vec![NodeId::new(0); n * k], + counts: vec![0; n], + k, + } + } + + /// Add a reverse edge: node `from` is a neighbor of node `to`, so `to` has reverse edge to `from`. + /// Uses reservoir sampling to maintain at most k reverse neighbors. + #[inline] + fn add(&mut self, to: usize, from: NodeId, rng: &mut impl Rng) { + let count = self.counts[to] as usize; + let start = to * self.k; + + if count < self.k { + // Still have room, just append + self.data[start + count] = from; + } else { + // Reservoir sampling: replace with probability k / (count + 1) + let j = rng.gen_range(0..=count); + if j < self.k { + self.data[start + j] = from; + } + } + self.counts[to] += 1; + } + + /// Get the reverse neighbors of node i (only the filled slots) + #[inline] + fn get(&self, i: usize) -> &[NodeId] { + let start = i * self.k; + let count = (self.counts[i] as usize).min(self.k); + &self.data[start..start + count] + } +} + +/// Build reverse neighbor lists with reservoir sampling. +/// Returns separate old and new reverse neighbor structures. +fn build_reverse_lists(graph: &AnnGraph, rng: &mut impl Rng) -> (ReverseNeighbors, ReverseNeighbors) { let n = graph.n(); - let mut old_reverse: Vec> = vec![Vec::new(); n]; - let mut new_reverse: Vec> = vec![Vec::new(); n]; + let k = graph.k(); + let mut old_reverse = ReverseNeighbors::new(n, k); + let mut new_reverse = ReverseNeighbors::new(n, k); for i in 0..n { let i_id = NodeId::new(i as u32); for neighbor in graph.neighbors(i) { let target = neighbor.index.index() as usize; if neighbor.index.is_new() { - new_reverse[target].push(i_id); + new_reverse.add(target, i_id, rng); } else { - old_reverse[target].push(i_id); + old_reverse.add(target, i_id, rng); } } } @@ -150,8 +201,8 @@ where // NN-Descent iterations for iter in 0..config.nn_descent_iterations { - // Build reverse neighbor lists, separating old and new - let (old_reverse, new_reverse) = build_reverse_lists(&graph); + // Build reverse neighbor lists, separating old and new (with reservoir sampling) + let (old_reverse, new_reverse) = build_reverse_lists(&graph, rng); // For each point, collect old and new forward neighbors let mut old_neighbors: Vec> = vec![Vec::new(); n]; @@ -185,12 +236,12 @@ where // Combine forward and reverse neighbors let old_i: Vec = old_neighbors[i] .iter() - .chain(old_reverse[i].iter()) + .chain(old_reverse.get(i).iter()) .copied() .collect(); let new_i: Vec = new_neighbors[i] .iter() - .chain(new_reverse[i].iter()) + .chain(new_reverse.get(i).iter()) .copied() .collect(); @@ -216,9 +267,9 @@ where candidates.insert(v); } } - for &v in &new_reverse[u_idx] { - if v != i_id && !current_neighbors.contains(&v) { - candidates.insert(v); + for v in new_reverse.get(u_idx) { + if *v != i_id && !current_neighbors.contains(v) { + candidates.insert(*v); } } } @@ -231,9 +282,9 @@ where candidates.insert(v); } } - for &v in &old_reverse[u_idx] { - if v != i_id && !current_neighbors.contains(&v) { - candidates.insert(v); + for v in old_reverse.get(u_idx) { + if *v != i_id && !current_neighbors.contains(v) { + candidates.insert(*v); } } } @@ -246,9 +297,9 @@ where candidates.insert(v); } } - for &v in &new_reverse[u_idx] { - if v != i_id && 
!current_neighbors.contains(&v) { - candidates.insert(v); + for v in new_reverse.get(u_idx) { + if *v != i_id && !current_neighbors.contains(v) { + candidates.insert(*v); } } } From 5477533aea61203e5e43e94bc3c6df2de909e2dd Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 12:06:50 +0100 Subject: [PATCH 24/29] Update nn_descent.rs --- crates/famst/src/nn_descent.rs | 154 ++++++++++++++------------------- 1 file changed, 64 insertions(+), 90 deletions(-) diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index 9a8f3dd..2b3b703 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -14,48 +14,54 @@ use std::collections::HashSet; use crate::{AnnGraph, FamstConfig, Neighbor, NodeId}; -/// Reverse neighbor lists with bounded size k per node. -/// Uses flat storage: data[i*k..(i+1)*k] contains up to k reverse neighbors of node i. -/// Uses reservoir sampling when more than k reverse edges exist. -struct ReverseNeighbors { - /// Flat storage of reverse neighbor IDs (with new flag preserved) +/// Neighbor lists with bounded size k per node. +/// Uses flat storage: data[i*k..(i+1)*k] contains up to k neighbors of node i. +/// Uses reservoir sampling when more than k neighbors exist. +struct Neighbors { + /// Flat storage of neighbor IDs data: Vec, - /// Count of reverse neighbors seen so far (for reservoir sampling) + /// Count of neighbors seen so far (for reservoir sampling) counts: Vec, - /// Max reverse neighbors per node + /// Max neighbors per node k: usize, } -impl ReverseNeighbors { +impl Neighbors { fn new(n: usize, k: usize) -> Self { - ReverseNeighbors { + Neighbors { data: vec![NodeId::new(0); n * k], counts: vec![0; n], k, } } - /// Add a reverse edge: node `from` is a neighbor of node `to`, so `to` has reverse edge to `from`. - /// Uses reservoir sampling to maintain at most k reverse neighbors. + /// Add a neighbor using reservoir sampling to maintain at most k neighbors. + /// Skips if the neighbor is already present. #[inline] - fn add(&mut self, to: usize, from: NodeId, rng: &mut impl Rng) { - let count = self.counts[to] as usize; - let start = to * self.k; + fn add(&mut self, node: usize, neighbor: NodeId, rng: &mut impl Rng) { + let count = self.counts[node] as usize; + let start = node * self.k; + let filled = count.min(self.k); + + // Check if neighbor already exists in the filled portion + if self.data[start..start + filled].contains(&neighbor) { + return; + } if count < self.k { // Still have room, just append - self.data[start + count] = from; + self.data[start + count] = neighbor; } else { // Reservoir sampling: replace with probability k / (count + 1) let j = rng.gen_range(0..=count); if j < self.k { - self.data[start + j] = from; + self.data[start + j] = neighbor; } } - self.counts[to] += 1; + self.counts[node] += 1; } - /// Get the reverse neighbors of node i (only the filled slots) + /// Get the neighbors of node i (only the filled slots) #[inline] fn get(&self, i: usize) -> &[NodeId] { let start = i * self.k; @@ -64,26 +70,46 @@ impl ReverseNeighbors { } } -/// Build reverse neighbor lists with reservoir sampling. -/// Returns separate old and new reverse neighbor structures. -fn build_reverse_lists(graph: &AnnGraph, rng: &mut impl Rng) -> (ReverseNeighbors, ReverseNeighbors) { +/// Build combined neighbor lists (forward + reverse) with reservoir sampling. +/// Returns (old_neighbors, new_neighbors), each with 2*k capacity per node. 
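+///
+/// This split implements the "local join" idea from Dong et al. (2011): a candidate
+/// pair is only evaluated later when at least one side is still marked "new", since
+/// pairs of two "old" neighbors were already compared in an earlier iteration and
+/// re-checking them would be redundant.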
+/// Only marks neighbors that were selected into new_neighbors as old. +fn build_neighbor_lists(graph: &mut AnnGraph, rng: &mut impl Rng) -> (Neighbors, Neighbors) { let n = graph.n(); let k = graph.k(); - let mut old_reverse = ReverseNeighbors::new(n, k); - let mut new_reverse = ReverseNeighbors::new(n, k); + // 2*k capacity: k for forward + k for reverse + let mut old_neighbors = Neighbors::new(n, 2 * k); + let mut new_neighbors = Neighbors::new(n, 2 * k); for i in 0..n { let i_id = NodeId::new(i as u32); for neighbor in graph.neighbors(i) { let target = neighbor.index.index() as usize; + let target_id = neighbor.index.as_old(); // Strip new flag for storage if neighbor.index.is_new() { - new_reverse.add(target, i_id, rng); + // Forward: i -> target, Reverse: target <- i + new_neighbors.add(i, target_id, rng); + new_neighbors.add(target, i_id, rng); } else { - old_reverse.add(target, i_id, rng); + old_neighbors.add(i, target_id, rng); + old_neighbors.add(target, i_id, rng); } } } - (old_reverse, new_reverse) + + // Only mark neighbors as old if they were selected into new_neighbors + for i in 0..n { + for &selected_id in new_neighbors.get(i) { + // Find this neighbor in the graph and mark as old if it's still new + for nb in graph.neighbors_mut(i) { + if nb.index == selected_id { + nb.index = nb.index.as_old(); + break; + } + } + } + } + + (old_neighbors, new_neighbors) } /// Initialize ANN graph with random neighbors @@ -201,30 +227,10 @@ where // NN-Descent iterations for iter in 0..config.nn_descent_iterations { - // Build reverse neighbor lists, separating old and new (with reservoir sampling) - let (old_reverse, new_reverse) = build_reverse_lists(&graph, rng); - - // For each point, collect old and new forward neighbors - let mut old_neighbors: Vec> = vec![Vec::new(); n]; - let mut new_neighbors: Vec> = vec![Vec::new(); n]; - - for i in 0..n { - for nb in graph.neighbors(i) { - let idx = nb.index.as_old(); // Strip new flag for storage - if nb.index.is_new() { - new_neighbors[i].push(idx); - } else { - old_neighbors[i].push(idx); - } - } - } - - // Mark all neighbors as old for next iteration - for i in 0..n { - for nb in graph.neighbors_mut(i) { - nb.index = nb.index.as_old(); - } - } + println!("NN-Descent iteration {iter}..."); + // Build combined neighbor lists (forward + reverse, with reservoir sampling) + // Also marks all neighbors as old for next iteration + let (old_neighbors, new_neighbors) = build_neighbor_lists(&mut graph, rng); let mut updates = 0; @@ -233,17 +239,8 @@ where for i in 0..n { let i_id = NodeId::new(i as u32); - // Combine forward and reverse neighbors - let old_i: Vec = old_neighbors[i] - .iter() - .chain(old_reverse.get(i).iter()) - .copied() - .collect(); - let new_i: Vec = new_neighbors[i] - .iter() - .chain(new_reverse.get(i).iter()) - .copied() - .collect(); + let old_i = old_neighbors.get(i); + let new_i = new_neighbors.get(i); // Skip if no new neighbors if new_i.is_empty() { @@ -260,14 +257,9 @@ where let mut candidates: HashSet = HashSet::new(); // new-new pairs: for each new neighbor, look at their new neighbors - for &u in &new_i { + for &u in new_i { let u_idx = u.index() as usize; - for &v in &new_neighbors[u_idx] { - if v != i_id && !current_neighbors.contains(&v) { - candidates.insert(v); - } - } - for v in new_reverse.get(u_idx) { + for v in new_neighbors.get(u_idx) { if *v != i_id && !current_neighbors.contains(v) { candidates.insert(*v); } @@ -275,14 +267,9 @@ where } // new-old pairs: for each new neighbor, look at their old neighbors 
- for &u in &new_i { + for &u in new_i { let u_idx = u.index() as usize; - for &v in &old_neighbors[u_idx] { - if v != i_id && !current_neighbors.contains(&v) { - candidates.insert(v); - } - } - for v in old_reverse.get(u_idx) { + for v in old_neighbors.get(u_idx) { if *v != i_id && !current_neighbors.contains(v) { candidates.insert(*v); } @@ -290,14 +277,9 @@ where } // old-new pairs: for each old neighbor, look at their new neighbors - for &u in &old_i { + for &u in old_i { let u_idx = u.index() as usize; - for &v in &new_neighbors[u_idx] { - if v != i_id && !current_neighbors.contains(&v) { - candidates.insert(v); - } - } - for v in new_reverse.get(u_idx) { + for v in new_neighbors.get(u_idx) { if *v != i_id && !current_neighbors.contains(v) { candidates.insert(*v); } @@ -320,13 +302,5 @@ where break; } } - - // Strip the new flag from all neighbors before returning - for i in 0..n { - for nb in graph.neighbors_mut(i) { - nb.index = nb.index.as_old(); - } - } - graph } From 34b6e1fe16286aad9f9eb9290a455edaaf3b1fc2 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 12:44:42 +0100 Subject: [PATCH 25/29] Use more cache-friendly loop --- crates/famst/src/lib.rs | 7 ++ crates/famst/src/nn_descent.rs | 139 +++++++++++++++------------------ 2 files changed, 68 insertions(+), 78 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 82ae640..1302d70 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -17,6 +17,8 @@ mod union_find; use nn_descent::nn_descent; use rand::seq::SliceRandom; use rand::Rng; +use rayon::iter::IndexedParallelIterator; +use rayon::slice::ParallelSliceMut; use union_find::{find_components, UnionFind}; /// Node index with embedded "new" flag in the least significant bit. @@ -169,6 +171,11 @@ impl AnnGraph { let start = i * self.k; &mut self.data[start..start + self.k] } + + /// Get mutable access to all neighbor chunks for parallel processing + pub(crate) fn neighbors_chunks_mut(&mut self) -> impl IndexedParallelIterator { + self.data.par_chunks_mut(self.k) + } } /// FAMST algorithm configuration diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index 2b3b703..52eb22d 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -19,7 +19,7 @@ use crate::{AnnGraph, FamstConfig, Neighbor, NodeId}; /// Uses reservoir sampling when more than k neighbors exist. struct Neighbors { /// Flat storage of neighbor IDs - data: Vec, + data: Vec, /// Count of neighbors seen so far (for reservoir sampling) counts: Vec, /// Max neighbors per node @@ -29,7 +29,7 @@ struct Neighbors { impl Neighbors { fn new(n: usize, k: usize) -> Self { Neighbors { - data: vec![NodeId::new(0); n * k], + data: vec![0; n * k], counts: vec![0; n], k, } @@ -38,7 +38,7 @@ impl Neighbors { /// Add a neighbor using reservoir sampling to maintain at most k neighbors. /// Skips if the neighbor is already present. 
#[inline] - fn add(&mut self, node: usize, neighbor: NodeId, rng: &mut impl Rng) { + fn add(&mut self, node: usize, neighbor: u32, rng: &mut impl Rng) { let count = self.counts[node] as usize; let start = node * self.k; let filled = count.min(self.k); @@ -63,7 +63,7 @@ impl Neighbors { /// Get the neighbors of node i (only the filled slots) #[inline] - fn get(&self, i: usize) -> &[NodeId] { + fn get(&self, i: usize) -> &[u32] { let start = i * self.k; let count = (self.counts[i] as usize).min(self.k); &self.data[start..start + count] @@ -77,37 +77,38 @@ fn build_neighbor_lists(graph: &mut AnnGraph, rng: &mut impl Rng) -> (Neighbors, let n = graph.n(); let k = graph.k(); // 2*k capacity: k for forward + k for reverse - let mut old_neighbors = Neighbors::new(n, 2 * k); - let mut new_neighbors = Neighbors::new(n, 2 * k); + let mut old_neighbors = Neighbors::new(n, k * 2); + let mut new_neighbors = Neighbors::new(n, k * 2); for i in 0..n { - let i_id = NodeId::new(i as u32); for neighbor in graph.neighbors(i) { - let target = neighbor.index.index() as usize; - let target_id = neighbor.index.as_old(); // Strip new flag for storage + let target = neighbor.index.index(); if neighbor.index.is_new() { // Forward: i -> target, Reverse: target <- i - new_neighbors.add(i, target_id, rng); - new_neighbors.add(target, i_id, rng); + new_neighbors.add(i, target, rng); + new_neighbors.add(target as usize, i as u32, rng); } else { - old_neighbors.add(i, target_id, rng); - old_neighbors.add(target, i_id, rng); + old_neighbors.add(i, target, rng); + old_neighbors.add(target as usize, i as u32, rng); } } } // Only mark neighbors as old if they were selected into new_neighbors - for i in 0..n { - for &selected_id in new_neighbors.get(i) { - // Find this neighbor in the graph and mark as old if it's still new - for nb in graph.neighbors_mut(i) { - if nb.index == selected_id { - nb.index = nb.index.as_old(); - break; + graph + .neighbors_chunks_mut() + .enumerate() + .for_each(|(i, neighbors)| { + for &selected_id in new_neighbors.get(i) { + // Find this neighbor in the graph and mark as old if it's still new + for nb in neighbors.iter_mut() { + if nb.index.index() == selected_id { + nb.index = nb.index.as_old(); + break; + } } } - } - } + }); (old_neighbors, new_neighbors) } @@ -232,66 +233,48 @@ where // Also marks all neighbors as old for next iteration let (old_neighbors, new_neighbors) = build_neighbor_lists(&mut graph, rng); - let mut updates = 0; - // For each point, generate candidates from neighbors of neighbors // Key optimization: only consider pairs where at least one is "new" - for i in 0..n { - let i_id = NodeId::new(i as u32); - - let old_i = old_neighbors.get(i); - let new_i = new_neighbors.get(i); - - // Skip if no new neighbors - if new_i.is_empty() { - continue; - } - - // Build set of current neighbors for O(1) lookup - let current_neighbors: HashSet = graph - .neighbors(i) - .iter() - .map(|nb| nb.index.as_old()) - .collect(); - - let mut candidates: HashSet = HashSet::new(); - - // new-new pairs: for each new neighbor, look at their new neighbors - for &u in new_i { - let u_idx = u.index() as usize; - for v in new_neighbors.get(u_idx) { - if *v != i_id && !current_neighbors.contains(v) { - candidates.insert(*v); + let candidates: HashSet<(u32, u32)> = (0..n) + .into_par_iter() + .fold( + HashSet::new, + |mut local_candidates, i| { + let old_i = old_neighbors.get(i); + let new_i = new_neighbors.get(i); + + // Skip if no new neighbors + if !new_i.is_empty() { + for &u in new_i { + for &v in 
new_i { + if u < v { + local_candidates.insert((u, v)); + } + } + for &v in old_i { + if u != v { + local_candidates.insert((u.min(v), u.max(v))); + } + } + } } - } - } - - // new-old pairs: for each new neighbor, look at their old neighbors - for &u in new_i { - let u_idx = u.index() as usize; - for v in old_neighbors.get(u_idx) { - if *v != i_id && !current_neighbors.contains(v) { - candidates.insert(*v); - } - } - } - - // old-new pairs: for each old neighbor, look at their new neighbors - for &u in old_i { - let u_idx = u.index() as usize; - for v in new_neighbors.get(u_idx) { - if *v != i_id && !current_neighbors.contains(v) { - candidates.insert(*v); - } - } + local_candidates + }, + ) + .reduce(HashSet::new, |mut a, b| { + a.extend(b); + a + }); + + // Try to improve neighbors with candidates + let mut updates = 0; + for &(u, v) in &candidates { + let d = distance_fn(&data[u as usize], &data[v as usize]); + if insert_neighbor(graph.neighbors_mut(u as usize), NodeId::new(v), d) { + updates += 1; } - - // Try to improve neighbors with candidates - for c in candidates { - let d = distance_fn(&data[i], &data[c.index() as usize]); - if insert_neighbor(graph.neighbors_mut(i), c, d) { - updates += 1; - } + if insert_neighbor(graph.neighbors_mut(v as usize), NodeId::new(u), d) { + updates += 1; } } From ca7b07494777fccb61d0a597385c1bee95b9217e Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 14:23:26 +0100 Subject: [PATCH 26/29] slightly faster parallel version (but it seem like a single threaded implementation is actually faster?!?) --- crates/famst/src/nn_descent.rs | 56 ++++++++++++++++------------------ 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index 52eb22d..9496dc8 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -235,36 +235,34 @@ where // For each point, generate candidates from neighbors of neighbors // Key optimization: only consider pairs where at least one is "new" - let candidates: HashSet<(u32, u32)> = (0..n) + let mut candidates: Vec<(u32, u32)> = (0..n) .into_par_iter() - .fold( - HashSet::new, - |mut local_candidates, i| { - let old_i = old_neighbors.get(i); - let new_i = new_neighbors.get(i); - - // Skip if no new neighbors - if !new_i.is_empty() { - for &u in new_i { - for &v in new_i { - if u < v { - local_candidates.insert((u, v)); - } - } - for &v in old_i { - if u != v { - local_candidates.insert((u.min(v), u.max(v))); - } - } + .flat_map_iter(|i| { + let old_i = old_neighbors.get(i); + let new_i = new_neighbors.get(i); + + // new-new pairs: (u, v) where u < v + let new_new = new_i.iter().flat_map(|&u| { + new_i.iter().filter_map(move |&v| if u < v { Some((u, v)) } else { None }) + }); + + // new-old pairs: (min, max) where u != v + let new_old = new_i.iter().flat_map(|&u| { + old_i.iter().filter_map(move |&v| { + if u != v { + Some((u.min(v), u.max(v))) + } else { + None } - } - local_candidates - }, - ) - .reduce(HashSet::new, |mut a, b| { - a.extend(b); - a - }); + }) + }); + + new_new.chain(new_old) + }) + .take_any(n * k) + .collect(); + candidates.par_sort_unstable(); + candidates.dedup(); // Try to improve neighbors with candidates let mut updates = 0; @@ -278,7 +276,7 @@ where } } - println!("NN-Descent iteration {iter}: {updates} updates"); + println!("NN-Descent iteration {iter}: {updates} updates of {} candidates", candidates.len()); // Early termination if no updates if updates == 0 { From 
c8c1689cddc81f597754abe7fe0849bf8c539271 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 15:11:26 +0100 Subject: [PATCH 27/29] fastest version so far --- crates/famst/src/lib.rs | 4 +-- crates/famst/src/nn_descent.rs | 60 ++++++++++++++++++++++++---------- 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index 1302d70..b55ae1a 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -173,8 +173,8 @@ impl AnnGraph { } /// Get mutable access to all neighbor chunks for parallel processing - pub(crate) fn neighbors_chunks_mut(&mut self) -> impl IndexedParallelIterator { - self.data.par_chunks_mut(self.k) + pub(crate) fn neighbors_chunks_mut(&mut self, group_size: usize) -> impl IndexedParallelIterator { + self.data.par_chunks_mut(self.k * group_size) } } diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index 9496dc8..1fa9f52 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -96,7 +96,7 @@ fn build_neighbor_lists(graph: &mut AnnGraph, rng: &mut impl Rng) -> (Neighbors, // Only mark neighbors as old if they were selected into new_neighbors graph - .neighbors_chunks_mut() + .neighbors_chunks_mut(1) .enumerate() .for_each(|(i, neighbors)| { for &selected_id in new_neighbors.get(i) { @@ -235,7 +235,8 @@ where // For each point, generate candidates from neighbors of neighbors // Key optimization: only consider pairs where at least one is "new" - let mut candidates: Vec<(u32, u32)> = (0..n) + // Collect all candidates with their distances, including reverse edges + let mut candidates: Vec<(u32, u32, f32)> = (0..n) .into_par_iter() .flat_map_iter(|i| { let old_i = old_neighbors.get(i); @@ -243,7 +244,9 @@ where // new-new pairs: (u, v) where u < v let new_new = new_i.iter().flat_map(|&u| { - new_i.iter().filter_map(move |&v| if u < v { Some((u, v)) } else { None }) + new_i + .iter() + .filter_map(move |&v| if u < v { Some((u, v)) } else { None }) }); // new-old pairs: (min, max) where u != v @@ -257,26 +260,47 @@ where }) }); - new_new.chain(new_old) + new_new + .chain(new_old) + .flat_map(|(u, v)| { + let d = distance_fn(&data[u as usize], &data[v as usize]); + // Insert both (u, v) and (v, u) so we can parallelize by node + [(u, v, d), (v, u, d)] + }) }) - .take_any(n * k) .collect(); - candidates.par_sort_unstable(); + + // Sort by first node so candidates for each node are contiguous + candidates.par_sort_unstable_by_key(|(a, _, _)| *a); candidates.dedup(); - // Try to improve neighbors with candidates - let mut updates = 0; - for &(u, v) in &candidates { - let d = distance_fn(&data[u as usize], &data[v as usize]); - if insert_neighbor(graph.neighbors_mut(u as usize), NodeId::new(v), d) { - updates += 1; - } - if insert_neighbor(graph.neighbors_mut(v as usize), NodeId::new(u), d) { - updates += 1; - } - } + // Process in parallel by chunks of nodes + const CHUNK_SIZE: usize = 64; + let updates: usize = graph + .neighbors_chunks_mut(CHUNK_SIZE) + .enumerate() + .map(|(i, chunk_neighbors)| { + let start_node = (CHUNK_SIZE * i) as u32; + // Binary search to find the range of candidates for this node + let start = candidates.partition_point(|&(u, _, _)| u < start_node); + let mut count = 0; + for &(u, v, d) in &candidates[start..] 
{ + if u >= start_node + CHUNK_SIZE as u32 { + break; + } + let neighbors = &mut chunk_neighbors[(u - start_node) as usize * k..][..k]; + if insert_neighbor(neighbors, NodeId::new(v), d) { + count += 1; + } + } + count + }) + .sum(); - println!("NN-Descent iteration {iter}: {updates} updates of {} candidates", candidates.len()); + println!( + "NN-Descent iteration {iter}: {updates} updates of {} candidates", + candidates.len(), + ); // Early termination if no updates if updates == 0 { From 8eb4090132edee56878bb65e6db1c1217cbb552b Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 15:18:04 +0100 Subject: [PATCH 28/29] Update nn_descent.rs --- crates/famst/src/nn_descent.rs | 142 ++++++++++++++++++--------------- 1 file changed, 77 insertions(+), 65 deletions(-) diff --git a/crates/famst/src/nn_descent.rs b/crates/famst/src/nn_descent.rs index 1fa9f52..9a0e0fa 100644 --- a/crates/famst/src/nn_descent.rs +++ b/crates/famst/src/nn_descent.rs @@ -232,78 +232,90 @@ where // Build combined neighbor lists (forward + reverse, with reservoir sampling) // Also marks all neighbors as old for next iteration let (old_neighbors, new_neighbors) = build_neighbor_lists(&mut graph, rng); - - // For each point, generate candidates from neighbors of neighbors - // Key optimization: only consider pairs where at least one is "new" - // Collect all candidates with their distances, including reverse edges - let mut candidates: Vec<(u32, u32, f32)> = (0..n) - .into_par_iter() - .flat_map_iter(|i| { - let old_i = old_neighbors.get(i); - let new_i = new_neighbors.get(i); - - // new-new pairs: (u, v) where u < v - let new_new = new_i.iter().flat_map(|&u| { - new_i - .iter() - .filter_map(move |&v| if u < v { Some((u, v)) } else { None }) - }); - - // new-old pairs: (min, max) where u != v - let new_old = new_i.iter().flat_map(|&u| { - old_i.iter().filter_map(move |&v| { - if u != v { - Some((u.min(v), u.max(v))) - } else { - None + println!(" Built neighbor lists"); + + // Process in batches of central nodes to limit memory usage + let batch_size: usize = n / k; + let mut total_updates = 0; + let mut total_candidates = 0; + + for batch_start in (0..n).step_by(batch_size) { + let batch_end = (batch_start + batch_size).min(n); + + // For each point in this batch, generate candidates from neighbors of neighbors + // Key optimization: only consider pairs where at least one is "new" + // Collect all candidates with their distances, including reverse edges + let mut candidates: Vec<(u32, u32, f32)> = (batch_start..batch_end) + .into_par_iter() + .flat_map_iter(|i| { + let old_i = old_neighbors.get(i); + let new_i = new_neighbors.get(i); + + // new-new pairs: (u, v) where u < v + let new_new = new_i.iter().flat_map(|&u| { + new_i + .iter() + .filter_map(move |&v| if u < v { Some((u, v)) } else { None }) + }); + + // new-old pairs: (min, max) where u != v + let new_old = new_i.iter().flat_map(|&u| { + old_i.iter().filter_map(move |&v| { + if u != v { + Some((u.min(v), u.max(v))) + } else { + None + } + }) + }); + + new_new + .chain(new_old) + .flat_map(|(u, v)| { + let d = distance_fn(&data[u as usize], &data[v as usize]); + // Insert both (u, v) and (v, u) so we can parallelize by node + [(u, v, d), (v, u, d)] + }) + }) + .collect(); + + // Sort by first node so candidates for each node are contiguous + candidates.par_sort_unstable_by_key(|(a, _, _)| *a); + candidates.dedup(); + total_candidates += candidates.len(); + + // Process in parallel by chunks of nodes + const CHUNK_SIZE: usize = 64; + let 
updates: usize = graph + .neighbors_chunks_mut(CHUNK_SIZE) + .enumerate() + .map(|(i, chunk_neighbors)| { + let start_node = (CHUNK_SIZE * i) as u32; + // Binary search to find the range of candidates for this chunk + let start = candidates.partition_point(|&(u, _, _)| u < start_node); + let mut count = 0; + for &(u, v, d) in &candidates[start..] { + if u >= start_node + CHUNK_SIZE as u32 { + break; + } + let neighbors = &mut chunk_neighbors[(u - start_node) as usize * k..][..k]; + if insert_neighbor(neighbors, NodeId::new(v), d) { + count += 1; } - }) - }); - - new_new - .chain(new_old) - .flat_map(|(u, v)| { - let d = distance_fn(&data[u as usize], &data[v as usize]); - // Insert both (u, v) and (v, u) so we can parallelize by node - [(u, v, d), (v, u, d)] - }) - }) - .collect(); - - // Sort by first node so candidates for each node are contiguous - candidates.par_sort_unstable_by_key(|(a, _, _)| *a); - candidates.dedup(); - - // Process in parallel by chunks of nodes - const CHUNK_SIZE: usize = 64; - let updates: usize = graph - .neighbors_chunks_mut(CHUNK_SIZE) - .enumerate() - .map(|(i, chunk_neighbors)| { - let start_node = (CHUNK_SIZE * i) as u32; - // Binary search to find the range of candidates for this node - let start = candidates.partition_point(|&(u, _, _)| u < start_node); - let mut count = 0; - for &(u, v, d) in &candidates[start..] { - if u >= start_node + CHUNK_SIZE as u32 { - break; - } - let neighbors = &mut chunk_neighbors[(u - start_node) as usize * k..][..k]; - if insert_neighbor(neighbors, NodeId::new(v), d) { - count += 1; } - } - count - }) - .sum(); + count + }) + .sum(); + + total_updates += updates; + } println!( - "NN-Descent iteration {iter}: {updates} updates of {} candidates", - candidates.len(), + "NN-Descent iteration {iter}: {total_updates} updates of {total_candidates} candidates", ); // Early termination if no updates - if updates == 0 { + if total_updates == 0 { break; } } From 4de8ce94183d436aa4fda2cd0dafa16bb59612d7 Mon Sep 17 00:00:00 2001 From: Alexander Neubeck Date: Mon, 19 Jan 2026 16:29:06 +0100 Subject: [PATCH 29/29] Update lib.rs --- crates/famst/src/lib.rs | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/crates/famst/src/lib.rs b/crates/famst/src/lib.rs index b55ae1a..6fe69ee 100644 --- a/crates/famst/src/lib.rs +++ b/crates/famst/src/lib.rs @@ -10,6 +10,29 @@ //! 3. Iterative edge refinement //! //! Generic over data type `T` and distance function. +//! +//! While https://github.com/cmuparlay/ParlayANN is not specifically for MST construction, +//! it gives some idea how far one can scale ANN graph construction. +//! +//! One problem with very large graphs are cache misses. To address them, a common procedure +//! is to partition the data/graph into smaller buckets and solve the k-NN problem within each bucket. +//! To make this actually work, one needs to rearrange the data points according to the partitioning +//! to improve cache locality. A quite strong partitioning can be obtained via masked sorting of geo-filters. +//! For large chunks of the sorted data, a partial MST can be computed and then be merged with the previously +//! computed approximate MSTs. This approach is probably more efficient than the windowing approach currently +//! implemented in Blackbird, since many different masks are needed to get accurate results. +//! +//! The default sorting implementation of Rayon is not sufficient for very large data sets though. +//! 
The problem is that our data points are pretty large and copying them is costly. But leaving
+//! them in place and sorting only references results in cache misses. Therefore, one has to group the data into manageable chunks whose data
+//! is all in one contiguous block of memory, so that the actual sorting can be done via references.
+//! The simplest solution to chunking is to first sample a random subset of the data. Those data points
+//! are collected into a contiguous block, sorted via indirection, and then rearranged in memory if needed.
+//! In the second phase, all data is partitioned according to the sample points. This can be done efficiently
+//! in parallel. The intermediate output of this step is the partition each data point belongs to.
+//! At the end of this step, the data points can be copied to their locations within their partitions.
+//! In the last phase, each partition can be sorted individually via indirection and rearranged again at the end if needed.
+//! Note that all of those steps are cache-friendly and highly parallelisable.
 
 mod nn_descent;
 mod union_find;
@@ -802,7 +825,7 @@ mod tests {
     fn test_large_scale_vs_exact() {
         use rand::distributions::{Distribution, Uniform};
 
-        const N: usize = 1_000_000;
+        const N: usize = 10_000_000;
         const DIM: usize = 10;
 
         println!("Generating {} random {}-dimensional points...", N, DIM);